diff --git a/Telegram/SourceFiles/api/api_editing.cpp b/Telegram/SourceFiles/api/api_editing.cpp index f0f0d4cd2..b5a2e1085 100644 --- a/Telegram/SourceFiles/api/api_editing.cpp +++ b/Telegram/SourceFiles/api/api_editing.cpp @@ -157,7 +157,7 @@ void EditMessageWithUploadedMedia( void RescheduleMessage( not_null item, SendOptions options) { - const auto empty = [](const auto &r) {}; + const auto empty = [] {}; EditMessage(item, options, empty, empty); } diff --git a/Telegram/SourceFiles/apiwrap.cpp b/Telegram/SourceFiles/apiwrap.cpp index 6ebd12c68..a5b76284b 100644 --- a/Telegram/SourceFiles/apiwrap.cpp +++ b/Telegram/SourceFiles/apiwrap.cpp @@ -390,7 +390,7 @@ void ApiWrap::acceptTerms(bytes::const_span id) { void ApiWrap::checkChatInvite( const QString &hash, FnMut done, - FnMut fail) { + Fn fail) { request(base::take(_checkInviteRequestId)).cancel(); _checkInviteRequestId = request(MTPmessages_CheckChatInvite( MTP_string(hash) @@ -1227,7 +1227,7 @@ void ApiWrap::requestPeerSettings(not_null peer) { void ApiWrap::migrateChat( not_null chat, FnMut)> done, - FnMut fail) { + Fn fail) { const auto callback = [&] { return MigrateCallbacks{ std::move(done), std::move(fail) }; }; diff --git a/Telegram/SourceFiles/apiwrap.h b/Telegram/SourceFiles/apiwrap.h index b7c75744a..568092eaf 100644 --- a/Telegram/SourceFiles/apiwrap.h +++ b/Telegram/SourceFiles/apiwrap.h @@ -227,7 +227,7 @@ public: void checkChatInvite( const QString &hash, FnMut done, - FnMut fail); + Fn fail); void importChatInvite(const QString &hash); void requestChannelMembersForAdd( @@ -243,7 +243,7 @@ public: void migrateChat( not_null chat, FnMut)> done, - FnMut fail = nullptr); + Fn fail = nullptr); void markMediaRead(const base::flat_set> &items); void markMediaRead(not_null item); @@ -742,11 +742,11 @@ private: mtpRequestId _checkInviteRequestId = 0; FnMut _checkInviteDone; - FnMut _checkInviteFail; + Fn _checkInviteFail; struct MigrateCallbacks { FnMut)> done; - FnMut fail; + Fn fail; }; base::flat_map< not_null, diff --git a/Telegram/SourceFiles/export/export_api_wrap.cpp b/Telegram/SourceFiles/export/export_api_wrap.cpp index b74e4072a..04fb24e77 100644 --- a/Telegram/SourceFiles/export/export_api_wrap.cpp +++ b/Telegram/SourceFiles/export/export_api_wrap.cpp @@ -249,26 +249,26 @@ public: RequestBuilder( Original &&builder, - FnMut commonFailHandler); + Fn commonFailHandler); [[nodiscard]] RequestBuilder &done(FnMut &&handler); [[nodiscard]] RequestBuilder &done( FnMut &&handler); [[nodiscard]] RequestBuilder &fail( - FnMut &&handler); + Fn &&handler); mtpRequestId send(); private: Original _builder; - FnMut _commonFailHandler; + Fn _commonFailHandler; }; template ApiWrap::RequestBuilder::RequestBuilder( Original &&builder, - FnMut commonFailHandler) + Fn commonFailHandler) : _builder(std::move(builder)) , _commonFailHandler(std::move(commonFailHandler)) { } @@ -295,15 +295,15 @@ auto ApiWrap::RequestBuilder::done( template auto ApiWrap::RequestBuilder::fail( - FnMut &&handler + Fn &&handler ) -> RequestBuilder& { if (handler) { auto &silence_warning = _builder.fail([ common = base::take(_commonFailHandler), specific = std::move(handler) - ](RPCError &&error) mutable { + ](const RPCError &error) mutable { if (!specific(error)) { - common(std::move(error)); + common(error); } }); } @@ -364,7 +364,7 @@ auto ApiWrap::mainRequest(Request &&request) { return RequestBuilder>( std::move(original), - [=](RPCError &&result) { error(std::move(result)); }); + [=](const RPCError &result) { error(result); }); } template @@ -391,7 +391,7 @@ auto ApiWrap::fileRequest(const Data::FileLocation &location, int offset) { location.data, MTP_int(offset), MTP_int(kFileChunkSize)) - )).fail([=](RPCError &&result) { + )).fail([=](const RPCError &result) { if (result.type() == qstr("TAKEOUT_FILE_EMPTY") && _otherDataProcess != nullptr) { filePartDone( @@ -688,11 +688,11 @@ void ApiWrap::startMainSession(FnMut done) { return data.vid().v; }); done(); - }).fail([=](RPCError &&result) { - error(std::move(result)); + }).fail([=](const RPCError &result) { + error(result); }).toDC(MTP::ShiftDcId(0, MTP::kExportDcShift)).send(); - }).fail([=](RPCError &&result) { - error(std::move(result)); + }).fail([=](const RPCError &result) { + error(result); }).send(); } @@ -1945,12 +1945,13 @@ void ApiWrap::filePartUnavailable() { base::take(_fileProcess)->done(QString()); } -void ApiWrap::error(RPCError &&error) { - _errors.fire(std::move(error)); +void ApiWrap::error(const RPCError &error) { + _errors.fire_copy(error); } void ApiWrap::error(const QString &text) { - error(MTP_rpc_error(MTP_int(0), MTP_string("API_ERROR: " + text))); + error(RPCError( + MTP_rpc_error(MTP_int(0), MTP_string("API_ERROR: " + text)))); } void ApiWrap::ioError(const Output::Result &result) { diff --git a/Telegram/SourceFiles/export/export_api_wrap.h b/Telegram/SourceFiles/export/export_api_wrap.h index 0250ab1c6..6ce4bfb96 100644 --- a/Telegram/SourceFiles/export/export_api_wrap.h +++ b/Telegram/SourceFiles/export/export_api_wrap.h @@ -203,7 +203,7 @@ private: const Data::FileLocation &location, int offset); - void error(RPCError &&error); + void error(const RPCError &error); void error(const QString &text); void ioError(const Output::Result &result); diff --git a/Telegram/SourceFiles/export/export_controller.cpp b/Telegram/SourceFiles/export/export_controller.cpp index 4f3240d1a..f442c7693 100644 --- a/Telegram/SourceFiles/export/export_controller.cpp +++ b/Telegram/SourceFiles/export/export_controller.cpp @@ -138,8 +138,8 @@ ControllerObject::ControllerObject( : _api(mtproto, weak.runner()) , _state(PasswordCheckState{}) { _api.errors( - ) | rpl::start_with_next([=](RPCError &&error) { - setState(ApiErrorState{ std::move(error) }); + ) | rpl::start_with_next([=](const RPCError &error) { + setState(ApiErrorState{ error }); }, _lifetime); _api.ioErrors( diff --git a/Telegram/SourceFiles/main/main_account.cpp b/Telegram/SourceFiles/main/main_account.cpp index 3fe0abb4e..e9b0b8141 100644 --- a/Telegram/SourceFiles/main/main_account.cpp +++ b/Telegram/SourceFiles/main/main_account.cpp @@ -420,18 +420,14 @@ void Account::startMtp(std::unique_ptr config) { _mtpFields.mainDcId = _mtp->mainDcId(); - _mtp->setUpdatesHandler(::rpcDone([=]( - const mtpPrime *from, - const mtpPrime *end) { - return checkForUpdates(from, end) - || checkForNewSession(from, end); - })); - _mtp->setGlobalFailHandler(::rpcFail([=](const RPCError &error) { + _mtp->setUpdatesHandler([=](const MTP::Response &message) { + checkForUpdates(message) || checkForNewSession(message); + }); + _mtp->setGlobalFailHandler([=](const RPCError &, const MTP::Response &) { if (const auto session = maybeSession()) { crl::on_main(session, [=] { logOut(); }); } - return true; - })); + }); _mtp->setStateChangedHandler([=](MTP::ShiftedDcId dc, int32 state) { if (dc == _mtp->mainDcId()) { Global::RefConnectionTypeChanged().notify(); @@ -468,18 +464,20 @@ void Account::startMtp(std::unique_ptr config) { _mtpValue = _mtp.get(); } -bool Account::checkForUpdates(const mtpPrime *from, const mtpPrime *end) { +bool Account::checkForUpdates(const MTP::Response &message) { auto updates = MTPUpdates(); - if (!updates.read(from, end)) { + auto from = message.reply.constData(); + if (!updates.read(from, from + message.reply.size())) { return false; } _mtpUpdates.fire(std::move(updates)); return true; } -bool Account::checkForNewSession(const mtpPrime *from, const mtpPrime *end) { +bool Account::checkForNewSession(const MTP::Response &message) { auto newSession = MTPNewSession(); - if (!newSession.read(from, end)) { + auto from = message.reply.constData(); + if (!newSession.read(from, from + message.reply.size())) { return false; } _mtpNewSessionCreated.fire({}); diff --git a/Telegram/SourceFiles/main/main_account.h b/Telegram/SourceFiles/main/main_account.h index 8c14b1793..20ede974c 100644 --- a/Telegram/SourceFiles/main/main_account.h +++ b/Telegram/SourceFiles/main/main_account.h @@ -122,8 +122,8 @@ private: std::unique_ptr settings); void watchProxyChanges(); void watchSessionChanges(); - bool checkForUpdates(const mtpPrime *from, const mtpPrime *end); - bool checkForNewSession(const mtpPrime *from, const mtpPrime *end); + bool checkForUpdates(const MTP::Response &message); + bool checkForNewSession(const MTP::Response &message); void destroyMtpKeys(MTP::AuthKeysList &&keys); void resetAuthorizationKeys(); diff --git a/Telegram/SourceFiles/mtproto/config_loader.cpp b/Telegram/SourceFiles/mtproto/config_loader.cpp index 949c44947..0177a4019 100644 --- a/Telegram/SourceFiles/mtproto/config_loader.cpp +++ b/Telegram/SourceFiles/mtproto/config_loader.cpp @@ -27,8 +27,8 @@ constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to w ConfigLoader::ConfigLoader( not_null instance, const QString &phone, - RPCDoneHandlerPtr onDone, - RPCFailHandlerPtr onFail) + Fn onDone, + FailHandler onFail) : _instance(instance) , _phone(phone) , _doneHandler(onDone) @@ -50,9 +50,18 @@ void ConfigLoader::load() { } mtpRequestId ConfigLoader::sendRequest(ShiftedDcId shiftedDcId) { + auto done = [done = _doneHandler](const Response &response) { + auto from = response.reply.constData(); + auto result = MTPConfig(); + if (!result.read(from, from + response.reply.size())) { + return false; + } + done(result); + return true; + }; return _instance->send( MTPhelp_GetConfig(), - base::duplicate(_doneHandler), + std::move(done), base::duplicate(_failHandler), shiftedDcId); } @@ -191,11 +200,17 @@ void ConfigLoader::sendSpecialRequest() { endpoint->secret); _specialEnumRequest = _instance->send( MTPhelp_GetConfig(), - rpcDone([weak](const MTPConfig &result) { + [weak](const Response &response) { + auto result = MTPConfig(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { + return false; + } if (const auto strong = weak.get()) { strong->specialConfigLoaded(result); } - }), + return true; + }, base::duplicate(_failHandler), _specialEnumCurrent); _triedSpecialEndpoints.push_back(*endpoint); diff --git a/Telegram/SourceFiles/mtproto/config_loader.h b/Telegram/SourceFiles/mtproto/config_loader.h index 39b7bc214..969295138 100644 --- a/Telegram/SourceFiles/mtproto/config_loader.h +++ b/Telegram/SourceFiles/mtproto/config_loader.h @@ -25,8 +25,8 @@ public: ConfigLoader( not_null instance, const QString &phone, - RPCDoneHandlerPtr onDone, - RPCFailHandlerPtr onFail); + Fn onDone, + FailHandler onFail); ~ConfigLoader(); void load(); @@ -68,8 +68,8 @@ private: mtpRequestId _specialEnumRequest = 0; QString _phone; - RPCDoneHandlerPtr _doneHandler; - RPCFailHandlerPtr _failHandler; + Fn _doneHandler; + FailHandler _failHandler; }; diff --git a/Telegram/SourceFiles/mtproto/dedicated_file_loader.h b/Telegram/SourceFiles/mtproto/dedicated_file_loader.h index 0f57ebef5..2b7e50922 100644 --- a/Telegram/SourceFiles/mtproto/dedicated_file_loader.h +++ b/Telegram/SourceFiles/mtproto/dedicated_file_loader.h @@ -19,10 +19,10 @@ class WeakInstance : private QObject, private base::Subscriber { public: explicit WeakInstance(base::weak_ptr session); - template + template void send( - const T &request, - Fn done, + const Request &request, + Fn done, Fn fail, ShiftedDcId dcId = 0); @@ -162,39 +162,44 @@ void StartDedicatedLoader( const QString &folder, Fn)> ready); -template +template void WeakInstance::send( - const T &request, - Fn done, + const Request &request, + Fn done, Fn fail, MTP::ShiftedDcId dcId) { - using Response = typename T::ResponseType; + using Result = typename Request::ResponseType; if (!valid()) { reportUnavailable(fail); return; } const auto onDone = crl::guard((QObject*)this, [=]( - const Response &result, - mtpRequestId requestId) { - if (removeRequest(requestId)) { + const Response &response) { + auto result = Result(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { + return false; + } + if (removeRequest(response.requestId)) { done(result); } + return true; }); const auto onFail = crl::guard((QObject*)this, [=]( const RPCError &error, - mtpRequestId requestId) { + const Response &response) { if (MTP::isDefaultHandledError(error)) { return false; } - if (removeRequest(requestId)) { + if (removeRequest(response.requestId)) { fail(error); } return true; }); const auto requestId = _instance->send( request, - rpcDone(onDone), - rpcFail(onFail), + std::move(onDone), + std::move(onFail), dcId); _requests.emplace(requestId, fail); } diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.cpp b/Telegram/SourceFiles/mtproto/mtp_instance.cpp index 5da79cbe5..444e16622 100644 --- a/Telegram/SourceFiles/mtproto/mtp_instance.cpp +++ b/Telegram/SourceFiles/mtproto/mtp_instance.cpp @@ -119,7 +119,7 @@ public: void sendRequest( mtpRequestId requestId, SerializedRequest &&request, - RPCResponseHandler &&callbacks, + ResponseHandler &&callbacks, ShiftedDcId shiftedDcId, crl::time msCanWait, bool needsLayer, @@ -129,23 +129,30 @@ public: void storeRequest( mtpRequestId requestId, const SerializedRequest &request, - RPCResponseHandler &&callbacks); + ResponseHandler &&callbacks); SerializedRequest getRequest(mtpRequestId requestId); - void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end); - bool hasCallbacks(mtpRequestId requestId); - void globalCallback(const mtpPrime *from, const mtpPrime *end); + [[nodiscard]] bool hasCallback(mtpRequestId requestId) const; + void processCallback(const Response &response); + void processUpdate(const Response &message); void onStateChange(ShiftedDcId shiftedDcId, int32 state); void onSessionReset(ShiftedDcId shiftedDcId); // return true if need to clean request data - bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err); - inline bool rpcErrorOccured(mtpRequestId requestId, const RPCResponseHandler &handler, const RPCError &err) { - return rpcErrorOccured(requestId, handler.onFail, err); + bool rpcErrorOccured( + const Response &response, + const FailHandler &onFail, + const RPCError &error); + inline bool rpcErrorOccured( + const Response &response, + const ResponseHandler &handler, + const RPCError &error) { + return rpcErrorOccured(response, handler.fail, error); } - void setUpdatesHandler(RPCDoneHandlerPtr onDone); - void setGlobalFailHandler(RPCFailHandlerPtr onFail); + void setUpdatesHandler(Fn handler); + void setGlobalFailHandler( + Fn handler); void setStateChangedHandler(Fn handler); void setSessionResetHandler(Fn handler); void clearGlobalHandlers(); @@ -170,11 +177,15 @@ public: [[nodiscard]] rpl::lifetime &lifetime(); private: - void importDone(const MTPauth_Authorization &result, mtpRequestId requestId); - bool importFail(const RPCError &error, mtpRequestId requestId); - void exportDone(const MTPauth_ExportedAuthorization &result, mtpRequestId requestId); - bool exportFail(const RPCError &error, mtpRequestId requestId); - bool onErrorDefault(mtpRequestId requestId, const RPCError &error); + void importDone( + const MTPauth_Authorization &result, + const Response &response); + bool importFail(const RPCError &error, const Response &response); + void exportDone( + const MTPauth_ExportedAuthorization &result, + const Response &response); + bool exportFail(const RPCError &error, const Response &response); + bool onErrorDefault(const RPCError &error, const Response &response); void unpaused(); @@ -245,8 +256,8 @@ private: // holds target dcWithShift for auth export request std::map _authExportRequests; - std::map _parserMap; - QMutex _parserMapLock; + std::map _parserMap; + mutable QMutex _parserMapLock; std::map _requestMap; QReadWriteLock _requestMapLock; @@ -259,7 +270,8 @@ private: std::map> _authWaiters; - RPCResponseHandler _globalHandler; + Fn _updatesHandler; + Fn _globalFailHandler; Fn _stateChangedHandler; Fn _sessionResetHandler; @@ -450,8 +462,10 @@ void Instance::Private::requestConfig() { _configLoader = std::make_unique( _instance, _userPhone, - rpcDone([=](const MTPConfig &result) { configLoadDone(result); }), - rpcFail([=](const RPCError &error) { return configLoadFail(error); })); + [=](const MTPConfig &result) { configLoadDone(result); }, + [=](const RPCError &error, const Response &) { + return configLoadFail(error); + }); _configLoader->load(); } @@ -647,12 +661,13 @@ void Instance::Private::reInitConnection(DcId dcId) { } void Instance::Private::logout(Fn done) { - _instance->send(MTPauth_LogOut(), rpcDone([=] { - done(); - }), rpcFail([=] { + _instance->send(MTPauth_LogOut(), [=](Response) { done(); return true; - })); + }, [=](const RPCError&, Response) { + done(); + return true; + }); logoutGuestDcs(); } @@ -667,12 +682,14 @@ void Instance::Private::logoutGuestDcs() { continue; } const auto shiftedDcId = MTP::logoutDcId(dcId); - const auto requestId = _instance->send(MTPauth_LogOut(), rpcDone([=]( - mtpRequestId requestId) { - logoutGuestDone(requestId); - }), rpcFail([=](mtpRequestId requestId) { - return logoutGuestDone(requestId); - }), shiftedDcId); + const auto requestId = _instance->send(MTPauth_LogOut(), [=]( + const Response &response) { + logoutGuestDone(response.requestId); + return true; + }, [=](const RPCError &, const Response &response) { + logoutGuestDone(response.requestId); + return true; + }, shiftedDcId); _logoutGuestRequestIds.emplace(shiftedDcId, requestId); } } @@ -932,7 +949,7 @@ void Instance::Private::checkDelayedRequests() { void Instance::Private::sendRequest( mtpRequestId requestId, SerializedRequest &&request, - RPCResponseHandler &&callbacks, + ResponseHandler &&callbacks, ShiftedDcId shiftedDcId, crl::time msCanWait, bool needsLayer, @@ -980,8 +997,8 @@ void Instance::Private::unregisterRequest(mtpRequestId requestId) { void Instance::Private::storeRequest( mtpRequestId requestId, const SerializedRequest &request, - RPCResponseHandler &&callbacks) { - if (callbacks.onDone || callbacks.onFail) { + ResponseHandler &&callbacks) { + if (callbacks.done || callbacks.fail) { QMutexLocker locker(&_parserMapLock); _parserMap.emplace(requestId, std::move(callbacks)); } @@ -1003,53 +1020,58 @@ SerializedRequest Instance::Private::getRequest(mtpRequestId requestId) { return result; } +bool Instance::Private::hasCallback(mtpRequestId requestId) const { + QMutexLocker locker(&_parserMapLock); + auto it = _parserMap.find(requestId); + return (it != _parserMap.cend()); +} -void Instance::Private::execCallback( - mtpRequestId requestId, - const mtpPrime *from, - const mtpPrime *end) { - RPCResponseHandler h; +void Instance::Private::processCallback(const Response &response) { + const auto requestId = response.requestId; + ResponseHandler handler; { QMutexLocker locker(&_parserMapLock); auto it = _parserMap.find(requestId); if (it != _parserMap.cend()) { - h = it->second; + handler = std::move(it->second); _parserMap.erase(it); DEBUG_LOG(("RPC Info: found parser for request %1, trying to parse response...").arg(requestId)); } } - if (h.onDone || h.onFail) { + if (handler.done || handler.fail) { const auto handleError = [&](const RPCError &error) { DEBUG_LOG(("RPC Info: " "error received, code %1, type %2, description: %3" ).arg(error.code() ).arg(error.type() ).arg(error.description())); - if (rpcErrorOccured(requestId, h, error)) { + if (rpcErrorOccured(response, handler, error)) { unregisterRequest(requestId); } else { QMutexLocker locker(&_parserMapLock); - _parserMap.emplace(requestId, h); + _parserMap.emplace(requestId, std::move(handler)); } }; - if (from >= end) { + auto from = response.reply.constData(); + if (response.reply.isEmpty()) { handleError(RPCError::Local( "RESPONSE_PARSE_FAILED", "Empty response.")); } else if (*from == mtpc_rpc_error) { auto error = MTPRpcError(); - handleError(error.read(from, end) ? error : RPCError::Local( - "RESPONSE_PARSE_FAILED", - "Error parse failed.")); - } else { - if (h.onDone) { - if (!(*h.onDone)(requestId, from, end)) { - handleError(RPCError::Local( + handleError( + RPCError(error.read(from, from + response.reply.size()) + ? error + : RPCError::MTPLocal( "RESPONSE_PARSE_FAILED", - "Response parse failed.")); - } + "Error parse failed."))); + } else { + if (handler.done && !handler.done(response)) { + handleError(RPCError::Local( + "RESPONSE_PARSE_FAILED", + "Response parse failed.")); } unregisterRequest(requestId); } @@ -1059,18 +1081,10 @@ void Instance::Private::execCallback( } } -bool Instance::Private::hasCallbacks(mtpRequestId requestId) { - QMutexLocker locker(&_parserMapLock); - auto it = _parserMap.find(requestId); - return (it != _parserMap.cend()); -} - -void Instance::Private::globalCallback(const mtpPrime *from, const mtpPrime *end) { - if (!_globalHandler.onDone) { - return; +void Instance::Private::processUpdate(const Response &message) { + if (_updatesHandler) { + _updatesHandler(message); } - // Handle updates. - [[maybe_unused]] bool result = (*_globalHandler.onDone)(0, from, end); } void Instance::Private::onStateChange(ShiftedDcId dcWithShift, int32 state) { @@ -1085,25 +1099,40 @@ void Instance::Private::onSessionReset(ShiftedDcId dcWithShift) { } } -bool Instance::Private::rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err) { // return true if need to clean request data - if (isDefaultHandledError(err)) { - if (onFail && (*onFail)(requestId, err)) { +bool Instance::Private::rpcErrorOccured( + const Response &response, + const FailHandler &onFail, + const RPCError &error) { // return true if need to clean request data + if (isDefaultHandledError(error)) { + if (onFail && onFail(error, response)) { return true; } } - if (onErrorDefault(requestId, err)) { + if (onErrorDefault(error, response)) { return false; } - LOG(("RPC Error: request %1 got fail with code %2, error %3%4").arg(requestId).arg(err.code()).arg(err.type()).arg(err.description().isEmpty() ? QString() : QString(": %1").arg(err.description()))); - onFail && (*onFail)(requestId, err); + LOG(("RPC Error: request %1 got fail with code %2, error %3%4" + ).arg(response.requestId + ).arg(error.code() + ).arg(error.type() + ).arg(error.description().isEmpty() + ? QString() + : QString(": %1").arg(error.description()))); + if (onFail) { + onFail(error, response); + } return true; } -void Instance::Private::importDone(const MTPauth_Authorization &result, mtpRequestId requestId) { - const auto shiftedDcId = queryRequestByDc(requestId); +void Instance::Private::importDone( + const MTPauth_Authorization &result, + const Response &response) { + const auto shiftedDcId = queryRequestByDc(response.requestId); if (!shiftedDcId) { - LOG(("MTP Error: auth import request not found in requestsByDC, requestId: %1").arg(requestId)); + LOG(("MTP Error: " + "auth import request not found in requestsByDC, requestId: %1" + ).arg(response.requestId)); // // Don't log out on export/import problems, perhaps this is a server side error. // @@ -1111,8 +1140,8 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque // "AUTH_IMPORT_FAIL", // QString("did not find import request in requestsByDC, " // "request %1").arg(requestId)); - //if (_globalHandler.onFail && hasAuthorization()) { - // (*_globalHandler.onFail)(requestId, error); // auth failed in main dc + //if (_globalFailHandler && hasAuthorization()) { + // _globalFailHandler(error, response); // auth failed in main dc //} return; } @@ -1144,22 +1173,30 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque } } -bool Instance::Private::importFail(const RPCError &error, mtpRequestId requestId) { - if (isDefaultHandledError(error)) return false; +bool Instance::Private::importFail( + const RPCError &error, + const Response &response) { + if (isDefaultHandledError(error)) { + return false; + } // // Don't log out on export/import problems, perhaps this is a server side error. // - //if (_globalHandler.onFail && hasAuthorization()) { - // (*_globalHandler.onFail)(requestId, error); // auth import failed + //if (_globalFailHandler && hasAuthorization()) { + // _globalFailHandler(error, response); // auth import failed //} return true; } -void Instance::Private::exportDone(const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { - auto it = _authExportRequests.find(requestId); +void Instance::Private::exportDone( + const MTPauth_ExportedAuthorization &result, + const Response &response) { + auto it = _authExportRequests.find(response.requestId); if (it == _authExportRequests.cend()) { - LOG(("MTP Error: auth export request target dcWithShift not found, requestId: %1").arg(requestId)); + LOG(("MTP Error: " + "auth export request target dcWithShift not found, requestId: %1" + ).arg(response.requestId)); // // Don't log out on export/import problems, perhaps this is a server side error. // @@ -1167,46 +1204,62 @@ void Instance::Private::exportDone(const MTPauth_ExportedAuthorization &result, // "AUTH_IMPORT_FAIL", // QString("did not find target dcWithShift, request %1" // ).arg(requestId)); - //if (_globalHandler.onFail && hasAuthorization()) { - // (*_globalHandler.onFail)(requestId, error); // auth failed in main dc + //if (_globalFailHandler && hasAuthorization()) { + // _globalFailHandler(error, response); // auth failed in main dc //} return; } auto &data = result.c_auth_exportedAuthorization(); - _instance->send(MTPauth_ImportAuthorization(data.vid(), data.vbytes()), rpcDone([this](const MTPauth_Authorization &result, mtpRequestId requestId) { - importDone(result, requestId); - }), rpcFail([this](const RPCError &error, mtpRequestId requestId) { - return importFail(error, requestId); - }), it->second); - _authExportRequests.erase(requestId); + _instance->send(MTPauth_ImportAuthorization( + data.vid(), + data.vbytes() + ), [this](const Response &response) { + auto result = MTPauth_Authorization(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { + return false; + } + importDone(result, response); + return true; + }, [this](const RPCError &error, const Response &response) { + return importFail(error, response); + }, it->second); + _authExportRequests.erase(response.requestId); } -bool Instance::Private::exportFail(const RPCError &error, mtpRequestId requestId) { - if (isDefaultHandledError(error)) return false; +bool Instance::Private::exportFail( + const RPCError &error, + const Response &response) { + if (isDefaultHandledError(error)) { + return false; + } - auto it = _authExportRequests.find(requestId); + auto it = _authExportRequests.find(response.requestId); if (it != _authExportRequests.cend()) { _authWaiters[BareDcId(it->second)].clear(); } // // Don't log out on export/import problems, perhaps this is a server side error. // - //if (_globalHandler.onFail && hasAuthorization()) { - // (*_globalHandler.onFail)(requestId, error); // auth failed in main dc + //if (_globalFailHandler && hasAuthorization()) { + // _globalFailHandler(error, response); // auth failed in main dc //} return true; } -bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &error) { - auto &err(error.type()); - auto code = error.code(); - if (!isFloodError(error) && err != qstr("AUTH_KEY_UNREGISTERED")) { +bool Instance::Private::onErrorDefault( + const RPCError &error, + const Response &response) { + const auto requestId = response.requestId; + const auto &type = error.type(); + const auto code = error.code(); + if (!isFloodError(error) && type != qstr("AUTH_KEY_UNREGISTERED")) { int breakpoint = 0; } - auto badGuestDc = (code == 400) && (err == qsl("FILE_ID_INVALID")); + auto badGuestDc = (code == 400) && (type == qsl("FILE_ID_INVALID")); QRegularExpressionMatch m; - if ((m = QRegularExpression("^(FILE|PHONE|NETWORK|USER)_MIGRATE_(\\d+)$").match(err)).hasMatch()) { + if ((m = QRegularExpression("^(FILE|PHONE|NETWORK|USER)_MIGRATE_(\\d+)$").match(type)).hasMatch()) { if (!requestId) return false; auto dcWithShift = ShiftedDcId(0); @@ -1228,11 +1281,19 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e //DEBUG_LOG(("MTP Info: importing auth to dc %1").arg(newdcWithShift)); //auto &waiters(_authWaiters[newdcWithShift]); //if (waiters.empty()) { - // auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(MTP_int(newdcWithShift)), rpcDone([this](const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { - // exportDone(result, requestId); - // }), rpcFail([this](const RPCError &error, mtpRequestId requestId) { - // return exportFail(error, requestId); - // })); + // auto exportRequestId = _instance->send(MTPauth_ExportAuthorization( + // MTP_int(newdcWithShift) + // ), [this](const Response &response) { + // auto result = MTPauth_ExportedAuthorization(); + // auto from = response.reply.constData(); + // if (!result.read(from, from + response.reply.size())) { + // return false; + // } + // exportDone(result, response); + // return true; + // }, [this](const RPCError &error, const Response &response) { + // return exportFail(error, response); + // }); // _authExportRequests.emplace(exportRequestId, newdcWithShift); //} //waiters.push_back(requestId); @@ -1260,7 +1321,7 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e (dcWithShift < 0) ? -newdcWithShift : newdcWithShift); session->sendPrepared(request); return true; - } else if (code < 0 || code >= 500 || (m = QRegularExpression("^FLOOD_WAIT_(\\d+)$").match(err)).hasMatch()) { + } else if (code < 0 || code >= 500 || (m = QRegularExpression("^FLOOD_WAIT_(\\d+)$").match(type)).hasMatch()) { if (!requestId) return false; int32 secs = 1; @@ -1286,7 +1347,7 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e checkDelayedRequests(); return true; - } else if ((code == 401 && err != "AUTH_KEY_PERM_EMPTY") + } else if ((code == 401 && type != qstr("AUTH_KEY_PERM_EMPTY")) || (badGuestDc && _badGuestDcRequests.find(requestId) == _badGuestDcRequests.cend())) { auto dcWithShift = ShiftedDcId(0); if (const auto shiftedDcId = queryRequestByDc(requestId)) { @@ -1296,26 +1357,36 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e } auto newdc = BareDcId(qAbs(dcWithShift)); if (!newdc || newdc == mainDcId()) { - if (!badGuestDc && _globalHandler.onFail) { - (*_globalHandler.onFail)(requestId, error); // auth failed in main dc + if (!badGuestDc && _globalFailHandler) { + _globalFailHandler(error, response); // auth failed in main dc } return false; } - DEBUG_LOG(("MTP Info: importing auth to dcWithShift %1").arg(dcWithShift)); + DEBUG_LOG(("MTP Info: importing auth to dcWithShift %1" + ).arg(dcWithShift)); auto &waiters(_authWaiters[newdc]); if (!waiters.size()) { - auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(MTP_int(newdc)), rpcDone([this](const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { - exportDone(result, requestId); - }), rpcFail([this](const RPCError &error, mtpRequestId requestId) { - return exportFail(error, requestId); - })); + auto exportRequestId = _instance->send(MTPauth_ExportAuthorization( + MTP_int(newdc) + ), [this](const Response &response) { + auto result = MTPauth_ExportedAuthorization(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { + return false; + } + exportDone(result, response); + return true; + }, [this](const RPCError &error, const Response &response) { + return exportFail(error, response); + }); _authExportRequests.emplace(exportRequestId, abs(dcWithShift)); } waiters.push_back(requestId); if (badGuestDc) _badGuestDcRequests.insert(requestId); return true; - } else if (err == qstr("CONNECTION_NOT_INITED") || err == qstr("CONNECTION_LAYER_INVALID")) { + } else if (type == qstr("CONNECTION_NOT_INITED") + || type == qstr("CONNECTION_LAYER_INVALID")) { SerializedRequest request; { QReadLocker locker(&_requestMapLock); @@ -1338,9 +1409,9 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e request->needsLayer = true; session->sendPrepared(request); return true; - } else if (err == qstr("CONNECTION_LANG_CODE_INVALID")) { + } else if (type == qstr("CONNECTION_LANG_CODE_INVALID")) { Lang::CurrentCloudManager().resetToDefault(); - } else if (err == qstr("MSG_WAIT_FAILED")) { + } else if (type == qstr("MSG_WAIT_FAILED")) { SerializedRequest request; { QReadLocker locker(&_requestMapLock); @@ -1514,15 +1585,16 @@ void Instance::Private::scheduleKeyDestroy(ShiftedDcId shiftedDcId) { if (dcOptions().dcType(shiftedDcId) == DcType::Cdn) { performKeyDestroy(shiftedDcId); } else { - _instance->send(MTPauth_LogOut(), rpcDone([=](const MTPBool &) { + _instance->send(MTPauth_LogOut(), [=](const Response &) { performKeyDestroy(shiftedDcId); - }), rpcFail([=](const RPCError &error) { + return true; + }, [=](const RPCError &error, const Response &) { if (isDefaultHandledError(error)) { return false; } performKeyDestroy(shiftedDcId); return true; - }), shiftedDcId); + }, shiftedDcId); } } @@ -1539,21 +1611,29 @@ void Instance::Private::keyWasPossiblyDestroyed(ShiftedDcId shiftedDcId) { void Instance::Private::performKeyDestroy(ShiftedDcId shiftedDcId) { Expects(isKeysDestroyer()); - _instance->send(MTPDestroy_auth_key(), rpcDone([=](const MTPDestroyAuthKeyRes &result) { - switch (result.type()) { - case mtpc_destroy_auth_key_ok: LOG(("MTP Info: key %1 destroyed.").arg(shiftedDcId)); break; - case mtpc_destroy_auth_key_fail: { - LOG(("MTP Error: key %1 destruction fail, leave it for now.").arg(shiftedDcId)); - killSession(shiftedDcId); - } break; - case mtpc_destroy_auth_key_none: LOG(("MTP Info: key %1 already destroyed.").arg(shiftedDcId)); break; + _instance->send(MTPDestroy_auth_key(), [=](const Response &response) { + auto result = MTPDestroyAuthKeyRes(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { + return false; } - _instance->keyWasPossiblyDestroyed(shiftedDcId); - }), rpcFail([=](const RPCError &error) { - LOG(("MTP Error: key %1 destruction resulted in error: %2").arg(shiftedDcId).arg(error.type())); + result.match([&](const MTPDdestroy_auth_key_ok &) { + LOG(("MTP Info: key %1 destroyed.").arg(shiftedDcId)); + }, [&](const MTPDdestroy_auth_key_fail &) { + LOG(("MTP Error: key %1 destruction fail, leave it for now." + ).arg(shiftedDcId)); + killSession(shiftedDcId); + }, [&](const MTPDdestroy_auth_key_none &) { + LOG(("MTP Info: key %1 already destroyed.").arg(shiftedDcId)); + }); _instance->keyWasPossiblyDestroyed(shiftedDcId); return true; - }), shiftedDcId); + }, [=](const RPCError &error, const Response &response) { + LOG(("MTP Error: key %1 destruction resulted in error: %2" + ).arg(shiftedDcId).arg(error.type())); + _instance->keyWasPossiblyDestroyed(shiftedDcId); + return true; + }, shiftedDcId); } void Instance::Private::completedKeyDestroy(ShiftedDcId shiftedDcId) { @@ -1582,27 +1662,31 @@ void Instance::Private::keyDestroyedOnServer( restart(shiftedDcId); } -void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) { - _globalHandler.onDone = onDone; +void Instance::Private::setUpdatesHandler( + Fn handler) { + _updatesHandler = std::move(handler); } -void Instance::Private::setGlobalFailHandler(RPCFailHandlerPtr onFail) { - _globalHandler.onFail = onFail; +void Instance::Private::setGlobalFailHandler( + Fn handler) { + _globalFailHandler = std::move(handler); } -void Instance::Private::setStateChangedHandler(Fn handler) { +void Instance::Private::setStateChangedHandler( + Fn handler) { _stateChangedHandler = std::move(handler); } -void Instance::Private::setSessionResetHandler(Fn handler) { +void Instance::Private::setSessionResetHandler( + Fn handler) { _sessionResetHandler = std::move(handler); } void Instance::Private::clearGlobalHandlers() { - setUpdatesHandler(RPCDoneHandlerPtr()); - setGlobalFailHandler(RPCFailHandlerPtr()); - setStateChangedHandler(Fn()); - setSessionResetHandler(Fn()); + setUpdatesHandler(nullptr); + setGlobalFailHandler(nullptr); + setStateChangedHandler(nullptr); + setSessionResetHandler(nullptr); } void Instance::Private::prepareToDestroy() { @@ -1806,19 +1890,22 @@ QString Instance::systemVersion() const { return _private->systemVersion(); } -void Instance::setUpdatesHandler(RPCDoneHandlerPtr onDone) { - _private->setUpdatesHandler(onDone); +void Instance::setUpdatesHandler(Fn handler) { + _private->setUpdatesHandler(std::move(handler)); } -void Instance::setGlobalFailHandler(RPCFailHandlerPtr onFail) { - _private->setGlobalFailHandler(onFail); +void Instance::setGlobalFailHandler( + Fn handler) { + _private->setGlobalFailHandler(std::move(handler)); } -void Instance::setStateChangedHandler(Fn handler) { +void Instance::setStateChangedHandler( + Fn handler) { _private->setStateChangedHandler(std::move(handler)); } -void Instance::setSessionResetHandler(Fn handler) { +void Instance::setSessionResetHandler( + Fn handler) { _private->setSessionResetHandler(std::move(handler)); } @@ -1834,20 +1921,23 @@ void Instance::onSessionReset(ShiftedDcId shiftedDcId) { _private->onSessionReset(shiftedDcId); } -void Instance::execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) { - _private->execCallback(requestId, from, end); +bool Instance::hasCallback(mtpRequestId requestId) const { + return _private->hasCallback(requestId); } -bool Instance::hasCallbacks(mtpRequestId requestId) { - return _private->hasCallbacks(requestId); +void Instance::processCallback(const Response &response) { + _private->processCallback(response); } -void Instance::globalCallback(const mtpPrime *from, const mtpPrime *end) { - _private->globalCallback(from, end); +void Instance::processUpdate(const Response &message) { + _private->processUpdate(message); } -bool Instance::rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err) { - return _private->rpcErrorOccured(requestId, onFail, err); +bool Instance::rpcErrorOccured( + const Response &response, + const FailHandler &onFail, + const RPCError &error) { + return _private->rpcErrorOccured(response, onFail, error); } bool Instance::isKeysDestroyer() const { @@ -1865,7 +1955,7 @@ void Instance::keyDestroyedOnServer(ShiftedDcId shiftedDcId, uint64 keyId) { void Instance::sendRequest( mtpRequestId requestId, SerializedRequest &&request, - RPCResponseHandler &&callbacks, + ResponseHandler &&callbacks, ShiftedDcId shiftedDcId, crl::time msCanWait, bool needsLayer, diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.h b/Telegram/SourceFiles/mtproto/mtp_instance.h index 55af28fb5..8598d19ba 100644 --- a/Telegram/SourceFiles/mtproto/mtp_instance.h +++ b/Telegram/SourceFiles/mtproto/mtp_instance.h @@ -102,21 +102,26 @@ public: void reInitConnection(DcId dcId); void logout(Fn done); - void setUpdatesHandler(RPCDoneHandlerPtr onDone); - void setGlobalFailHandler(RPCFailHandlerPtr onFail); - void setStateChangedHandler(Fn handler); + void setUpdatesHandler(Fn handler); + void setGlobalFailHandler( + Fn handler); + void setStateChangedHandler( + Fn handler); void setSessionResetHandler(Fn handler); void clearGlobalHandlers(); void onStateChange(ShiftedDcId shiftedDcId, int32 state); void onSessionReset(ShiftedDcId shiftedDcId); - void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end); - bool hasCallbacks(mtpRequestId requestId); - void globalCallback(const mtpPrime *from, const mtpPrime *end); + [[nodiscard]] bool hasCallback(mtpRequestId requestId) const; + void processCallback(const Response &response); + void processUpdate(const Response &message); // return true if need to clean request data - bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err); + bool rpcErrorOccured( + const Response &response, + const FailHandler &onFail, + const RPCError &err); // Thread-safe. bool isKeysDestroyer() const; @@ -141,7 +146,7 @@ public: template mtpRequestId send( const Request &request, - RPCResponseHandler &&callbacks = {}, + ResponseHandler &&callbacks = {}, ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0, mtpRequestId afterRequestId = 0) { @@ -159,14 +164,14 @@ public: template mtpRequestId send( const Request &request, - RPCDoneHandlerPtr &&onDone, - RPCFailHandlerPtr &&onFail = nullptr, + DoneHandler &&onDone, + FailHandler &&onFail = nullptr, ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0, mtpRequestId afterRequestId = 0) { return send( request, - RPCResponseHandler(std::move(onDone), std::move(onFail)), + ResponseHandler{ std::move(onDone), std::move(onFail) }, shiftedDcId, msCanWait, afterRequestId); @@ -191,7 +196,7 @@ public: void sendSerialized( mtpRequestId requestId, details::SerializedRequest &&request, - RPCResponseHandler &&callbacks, + ResponseHandler &&callbacks, ShiftedDcId shiftedDcId, crl::time msCanWait, mtpRequestId afterRequestId) { @@ -218,7 +223,7 @@ private: void sendRequest( mtpRequestId requestId, details::SerializedRequest &&request, - RPCResponseHandler &&callbacks, + ResponseHandler &&callbacks, ShiftedDcId shiftedDcId, crl::time msCanWait, bool needsLayer, diff --git a/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.cpp b/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.cpp index 50a2d95ab..a03658066 100644 --- a/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.cpp +++ b/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.cpp @@ -13,90 +13,61 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL namespace MTP { -class ConcurrentSender::RPCDoneHandler : public RPCAbstractDoneHandler { +class ConcurrentSender::HandlerMaker final { public: - RPCDoneHandler( + static ::MTP::DoneHandler MakeDone( not_null sender, Fn)> runner); - - bool operator()( - mtpRequestId requestId, - const mtpPrime *from, - const mtpPrime *end) override; - -private: - base::weak_ptr _weak; - Fn)> _runner; - -}; - -class ConcurrentSender::RPCFailHandler : public RPCAbstractFailHandler { -public: - RPCFailHandler( + static ::MTP::FailHandler MakeFail( not_null sender, Fn)> runner, FailSkipPolicy skipPolicy); - - bool operator()( - mtpRequestId requestId, - const RPCError &error) override; - -private: - base::weak_ptr _weak; - Fn)> _runner; - FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple; - }; -ConcurrentSender::RPCDoneHandler::RPCDoneHandler( - not_null sender, - Fn)> runner) -: _weak(sender) -, _runner(std::move(runner)) { +::MTP::DoneHandler ConcurrentSender::HandlerMaker::MakeDone( + not_null sender, + Fn)> runner) { + return [ + weak = base::make_weak(sender.get()), + runner = std::move(runner) + ](const Response &response) mutable { + runner([=]() mutable { + if (const auto strong = weak.get()) { + strong->senderRequestDone( + response.requestId, + bytes::make_span(response.reply)); + } + }); + return true; + }; } -bool ConcurrentSender::RPCDoneHandler::operator()( - mtpRequestId requestId, - const mtpPrime *from, - const mtpPrime *end) { - auto response = gsl::make_span( - from, - end - from); - _runner([=, weak = _weak, moved = bytes::make_vector(response)]() mutable { - if (const auto strong = weak.get()) { - strong->senderRequestDone(requestId, std::move(moved)); +::MTP::FailHandler ConcurrentSender::HandlerMaker::MakeFail( + not_null sender, + Fn)> runner, + FailSkipPolicy skipPolicy) { + return [ + weak = base::make_weak(sender.get()), + runner = std::move(runner), + skipPolicy + ](const RPCError &error, const Response &response) mutable { + if (skipPolicy == FailSkipPolicy::Simple) { + if (MTP::isDefaultHandledError(error)) { + return false; + } + } else if (skipPolicy == FailSkipPolicy::HandleFlood) { + if (MTP::isDefaultHandledError(error) + && !MTP::isFloodError(error)) { + return false; + } } - }); - return true; -} - -ConcurrentSender::RPCFailHandler::RPCFailHandler( - not_null sender, - Fn)> runner, - FailSkipPolicy skipPolicy) -: _weak(sender) -, _runner(std::move(runner)) -, _skipPolicy(skipPolicy) { -} - -bool ConcurrentSender::RPCFailHandler::operator()( - mtpRequestId requestId, - const RPCError &error) { - if (_skipPolicy == FailSkipPolicy::Simple) { - if (MTP::isDefaultHandledError(error)) { - return false; - } - } else if (_skipPolicy == FailSkipPolicy::HandleFlood) { - if (MTP::isDefaultHandledError(error) && !MTP::isFloodError(error)) { - return false; - } - } - _runner([=, weak = _weak, error = error]() mutable { - if (const auto strong = weak.get()) { - strong->senderRequestFail(requestId, std::move(error)); - } - }); - return true; + runner([=, requestId = response.requestId]() mutable { + if (const auto strong = weak.get()) { + strong->senderRequestFail(requestId, error); + } + }); + return true; + }; } template @@ -147,8 +118,8 @@ mtpRequestId ConcurrentSender::RequestBuilder::send() { _sender->with_instance([ =, request = std::move(_serialized), - done = std::make_shared(_sender, _sender->_runner), - fail = std::make_shared( + done = HandlerMaker::MakeDone(_sender, _sender->_runner), + fail = HandlerMaker::MakeFail( _sender, _sender->_runner, _failSkipPolicy) @@ -156,7 +127,7 @@ mtpRequestId ConcurrentSender::RequestBuilder::send() { instance->sendSerialized( requestId, std::move(request), - RPCResponseHandler(std::move(done), std::move(fail)), + ResponseHandler{ std::move(done), std::move(fail) }, dcId, msCanWait, afterRequestId); @@ -198,9 +169,9 @@ void ConcurrentSender::senderRequestDone( void ConcurrentSender::senderRequestFail( mtpRequestId requestId, - RPCError &&error) { + const RPCError &error) { if (auto handlers = _requests.take(requestId)) { - handlers->fail(requestId, std::move(error)); + handlers->fail(requestId, error); } } diff --git a/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.h b/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.h index ca47ef7f7..0c0971805 100644 --- a/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.h +++ b/Telegram/SourceFiles/mtproto/mtproto_concurrent_sender.h @@ -40,7 +40,7 @@ class ConcurrentSender : public base::has_weak_ptr { bytes::const_span result)>; using FailHandler = FnMut; + const RPCError &error)>; struct Handlers { DoneHandler done; FailHandler fail; @@ -95,7 +95,7 @@ public: template class SpecificRequestBuilder : public RequestBuilder { public: - using Response = typename Request::ResponseType; + using Result = typename Request::ResponseType; SpecificRequestBuilder( const SpecificRequestBuilder &other) = delete; @@ -115,18 +115,14 @@ public: // Allow code completion to show response type. [[nodiscard]] SpecificRequestBuilder &done(FnMut &&handler); [[nodiscard]] SpecificRequestBuilder &done( - FnMut &&handler); + FnMut &&handler); [[nodiscard]] SpecificRequestBuilder &done( - FnMut &&handler); - [[nodiscard]] SpecificRequestBuilder &done(FnMut &&handler); - [[nodiscard]] SpecificRequestBuilder &fail(FnMut &&handler); + FnMut &&handler); + [[nodiscard]] SpecificRequestBuilder &fail(Fn &&handler); [[nodiscard]] SpecificRequestBuilder &fail( - FnMut &&handler); + Fn &&handler); [[nodiscard]] SpecificRequestBuilder &fail( - FnMut &&handler); - [[nodiscard]] SpecificRequestBuilder &fail( - FnMut &&handler); + Fn &&handler); #else // !MTP_SENDER_USE_GENERIC_HANDLERS template [[nodiscard]] SpecificRequestBuilder &done(Handler &&handler); @@ -178,10 +174,8 @@ public: ~ConcurrentSender(); private: - class RPCDoneHandler; - friend class RPCDoneHandler; - class RPCFailHandler; - friend class RPCFailHandler; + class HandlerMaker; + friend class HandlerMaker; friend class RequestBuilder; friend class SentRequestWrap; @@ -191,7 +185,7 @@ private: bytes::const_span result); void senderRequestFail( mtpRequestId requestId, - RPCError &&error); + const RPCError &error); void senderRequestCancel(mtpRequestId requestId); void senderRequestCancelAll(); void senderRequestDetach(mtpRequestId requestId); @@ -202,7 +196,7 @@ private: }; -template +template void ConcurrentSender::RequestBuilder::setDoneHandler( InvokeFullDone &&invoke ) noexcept { @@ -211,11 +205,11 @@ void ConcurrentSender::RequestBuilder::setDoneHandler( bytes::const_span result) mutable { auto from = reinterpret_cast(result.data()); const auto end = from + result.size() / sizeof(mtpPrime); - Response data; + Result data; if (!data.read(from, end)) { return false; } - std::move(handler)(requestId, std::move(data)); + handler(requestId, std::move(data)); return true; }; } @@ -254,33 +248,21 @@ auto ConcurrentSender::SpecificRequestBuilder::afterDelay( // Allow code completion to show response type. template auto ConcurrentSender::SpecificRequestBuilder::done( - FnMut &&handler + FnMut &&handler ) -> SpecificRequestBuilder & { - setDoneHandler([handler = std::move(handler)]( + setDoneHandler([handler = std::move(handler)]( mtpRequestId requestId, - Response &&result) mutable { - std::move(handler)(std::move(result)); + Result &&result) mutable { + handler(std::move(result)); }); return *this; } template auto ConcurrentSender::SpecificRequestBuilder::done( - FnMut &&handler + FnMut &&handler ) -> SpecificRequestBuilder & { - setDoneHandler(std::move(handler)); - return *this; -} - -template -auto ConcurrentSender::SpecificRequestBuilder::done( - FnMut &&handler -) -> SpecificRequestBuilder & { - setDoneHandler([handler = std::move(handler)]( - mtpRequestId requestId, - Response &&result) mutable { - std::move(handler)(requestId); - }); + setDoneHandler(std::move(handler)); return *this; } @@ -288,9 +270,9 @@ template auto ConcurrentSender::SpecificRequestBuilder::done( FnMut &&handler ) -> SpecificRequestBuilder & { - setDoneHandler([handler = std::move(handler)]( + setDoneHandler([handler = std::move(handler)]( mtpRequestId requestId, - Response &&result) mutable { + Result &&result) mutable { std::move(handler)(); }); return *this; @@ -298,19 +280,19 @@ auto ConcurrentSender::SpecificRequestBuilder::done( template auto ConcurrentSender::SpecificRequestBuilder::fail( - FnMut &&handler + Fn &&handler ) -> SpecificRequestBuilder & { setFailHandler([handler = std::move(handler)]( mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(std::move(error)); + const RPCError &error) { + handler(error); }); return *this; } template auto ConcurrentSender::SpecificRequestBuilder::fail( - FnMut &&handler + Fn &&handler ) -> SpecificRequestBuilder & { setFailHandler(std::move(handler)); return *this; @@ -318,24 +300,12 @@ auto ConcurrentSender::SpecificRequestBuilder::fail( template auto ConcurrentSender::SpecificRequestBuilder::fail( - FnMut &&handler + Fn &&handler ) -> SpecificRequestBuilder & { setFailHandler([handler = std::move(handler)]( mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(requestId); - }); - return *this; -} - -template -auto ConcurrentSender::SpecificRequestBuilder::fail( - FnMut &&handler -) -> SpecificRequestBuilder & { - setFailHandler([handler = std::move(handler)]( - mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(); + const RPCError &error) { + handler(); }); return *this; } @@ -345,38 +315,29 @@ template auto ConcurrentSender::SpecificRequestBuilder::done( Handler &&handler ) -> SpecificRequestBuilder & { - using Response = typename Request::ResponseType; + using Result = typename Request::ResponseType; constexpr auto takesFull = rpl::details::is_callable_plain_v< Handler, mtpRequestId, - Response>; + Result>; constexpr auto takesResponse = rpl::details::is_callable_plain_v< Handler, - Response>; - constexpr auto takesRequestId = rpl::details::is_callable_plain_v< - Handler, - mtpRequestId>; + Result>; constexpr auto takesNone = rpl::details::is_callable_plain_v; if constexpr (takesFull) { - setDoneHandler(std::forward(handler)); + setDoneHandler(std::forward(handler)); } else if constexpr (takesResponse) { - setDoneHandler([handler = std::forward(handler)]( + setDoneHandler([handler = std::forward(handler)]( mtpRequestId requestId, - Response &&result) mutable { - std::move(handler)(std::move(result)); - }); - } else if constexpr (takesRequestId) { - setDoneHandler([handler = std::forward(handler)]( - mtpRequestId requestId, - Response &&result) mutable { - std::move(handler)(requestId); + Result &&result) mutable { + handler(std::move(result)); }); } else if constexpr (takesNone) { - setDoneHandler([handler = std::forward(handler)]( + setDoneHandler([handler = std::forward(handler)]( mtpRequestId requestId, - Response &&result) mutable { - std::move(handler)(); + Result &&result) mutable { + handler(); }); } else { static_assert(false_t(Handler{}), "Bad done handler."); @@ -396,9 +357,6 @@ auto ConcurrentSender::SpecificRequestBuilder::fail( constexpr auto takesError = rpl::details::is_callable_plain_v< Handler, RPCError>; - constexpr auto takesRequestId = rpl::details::is_callable_plain_v< - Handler, - mtpRequestId>; constexpr auto takesNone = rpl::details::is_callable_plain_v; if constexpr (takesFull) { @@ -406,20 +364,14 @@ auto ConcurrentSender::SpecificRequestBuilder::fail( } else if constexpr (takesError) { setFailHandler([handler = std::forward(handler)]( mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(std::move(error)); - }); - } else if constexpr (takesRequestId) { - setFailHandler([handler = std::forward(handler)]( - mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(requestId); + const RPCError &error) { + handler(error); }); } else if constexpr (takesNone) { setFailHandler([handler = std::forward(handler)]( mtpRequestId requestId, - RPCError &&error) mutable { - std::move(handler)(); + const RPCError &error) { + handler(); }); } else { static_assert(false_t(Handler{}), "Bad fail handler."); diff --git a/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.cpp b/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.cpp index 7d52baf7a..203a759af 100644 --- a/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.cpp +++ b/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.cpp @@ -9,6 +9,20 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL #include +namespace { + +[[nodiscard]] MTPrpcError ParseError(const mtpBuffer &reply) { + auto result = MTPRpcError(); + auto from = reply.constData(); + return result.read(from, from + reply.size()) + ? result + : RPCError::MTPLocal( + "RESPONSE_PARSE_FAILED", + "Error parse failed."); +} + +} // namespace + RPCError::RPCError(const MTPrpcError &error) : _code(error.c_rpc_error().verror_code().v) { QString text = qs(error.c_rpc_error().verror_message()); @@ -30,3 +44,37 @@ RPCError::RPCError(const MTPrpcError &error) } } } + +RPCError::RPCError(const mtpBuffer &reply) : RPCError(ParseError(reply)) { +} + +int32 RPCError::code() const { + return _code; +} + +const QString &RPCError::type() const { + return _type; +} + +const QString &RPCError::description() const { + return _description; +} + +MTPrpcError RPCError::MTPLocal( + const QString &type, + const QString &description) { + return MTP_rpc_error( + MTP_int(0), + MTP_bytes( + ("CLIENT_" + + type + + (description.length() + ? (": " + description) + : QString())).toUtf8())); +} + +RPCError RPCError::Local( + const QString &type, + const QString &description) { + return RPCError(MTPLocal(type, description)); +} diff --git a/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.h b/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.h index 0ca1bcc27..d2723fd77 100644 --- a/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.h +++ b/Telegram/SourceFiles/mtproto/mtproto_rpc_sender.h @@ -11,38 +11,27 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL class RPCError { public: - RPCError(const MTPrpcError &error); - - int32 code() const { - return _code; - } - - const QString &type() const { - return _type; - } - - const QString &description() const { - return _description; - } + explicit RPCError(const MTPrpcError &error); + explicit RPCError(const mtpBuffer &reply); enum { NoError, TimeoutError }; - static RPCError Local(const QString &type, const QString &description) { - return MTP_rpc_error( - MTP_int(0), - MTP_bytes( - ("CLIENT_" - + type - + (description.length() - ? (": " + description) - : QString())).toUtf8())); - } + [[nodiscard]] int32 code() const; + [[nodiscard]] const QString &type() const; + [[nodiscard]] const QString &description() const; + + [[nodiscard]] static RPCError Local( + const QString &type, + const QString &description); + [[nodiscard]] static MTPrpcError MTPLocal( + const QString &type, + const QString &description); private: - int32 _code; + int32 _code = 0; QString _type, _description; }; @@ -61,515 +50,18 @@ inline bool isDefaultHandledError(const RPCError &error) { return isTemporaryError(error); } -} // namespace MTP - -class RPCAbstractDoneHandler { // abstract done -public: - [[nodiscard]] virtual bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) = 0; - virtual ~RPCAbstractDoneHandler() { - } - -}; -using RPCDoneHandlerPtr = std::shared_ptr; - -class RPCAbstractFailHandler { // abstract fail -public: - virtual bool operator()(mtpRequestId requestId, const RPCError &e) = 0; - virtual ~RPCAbstractFailHandler() { - } -}; -using RPCFailHandlerPtr = std::shared_ptr; - -struct RPCResponseHandler { - RPCResponseHandler() = default; - RPCResponseHandler(RPCDoneHandlerPtr &&done, RPCFailHandlerPtr &&fail) - : onDone(std::move(done)) - , onFail(std::move(fail)) { - } - - RPCDoneHandlerPtr onDone; - RPCFailHandlerPtr onFail; - -}; - -class RPCDoneHandlerBare : public RPCAbstractDoneHandler { // done(from, end) - using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *); - -public: - RPCDoneHandlerBare(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return (*_onDone)(from, end); - } - -private: - CallbackType _onDone; - -}; - -class RPCDoneHandlerBareReq : public RPCAbstractDoneHandler { // done(from, end, req_id) - using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *, mtpRequestId); - -public: - RPCDoneHandlerBareReq(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return (*_onDone)(from, end, requestId); - } - -private: - CallbackType _onDone; - -}; - -template -class RPCDoneHandlerPlain : public RPCAbstractDoneHandler { // done(result) - using CallbackType = TReturn (*)(const TResponse &); - -public: - RPCDoneHandlerPlain(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - auto response = TResponse(); - if (!response.read(from, end)) { - return false; - } - (*_onDone)(std::move(response)); - return true; - } - -private: - CallbackType _onDone; - -}; - -template -class RPCDoneHandlerReq : public RPCAbstractDoneHandler { // done(result, req_id) - using CallbackType = TReturn (*)(const TResponse &, mtpRequestId); - -public: - RPCDoneHandlerReq(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - auto response = TResponse(); - if (!response.read(from, end)) { - return false; - } - (*_onDone)(std::move(response), requestId); - return true; - } - -private: - CallbackType _onDone; - -}; - -template -class RPCDoneHandlerNo : public RPCAbstractDoneHandler { // done() - using CallbackType = TReturn (*)(); - -public: - RPCDoneHandlerNo(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - (*_onDone)(); - return true; - } - -private: - CallbackType _onDone; - -}; - -template -class RPCDoneHandlerNoReq : public RPCAbstractDoneHandler { // done(req_id) - using CallbackType = TReturn (*)(mtpRequestId); - -public: - RPCDoneHandlerNoReq(CallbackType onDone) : _onDone(onDone) { - } - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - (*_onDone)(requestId); - return true; - } - -private: - CallbackType _onDone; - -}; - -class RPCFailHandlerPlain : public RPCAbstractFailHandler { // fail(error) - using CallbackType = bool (*)(const RPCError &); - -public: - RPCFailHandlerPlain(CallbackType onFail) : _onFail(onFail) { - } - bool operator()(mtpRequestId requestId, const RPCError &e) override { - return (*_onFail)(e); - } - -private: - CallbackType _onFail; - -}; - -class RPCFailHandlerReq : public RPCAbstractFailHandler { // fail(error, req_id) - using CallbackType = bool (*)(const RPCError &, mtpRequestId); - -public: - RPCFailHandlerReq(CallbackType onFail) : _onFail(onFail) { - } - bool operator()(mtpRequestId requestId, const RPCError &e) override { - return (*_onFail)(e, requestId); - } - -private: - CallbackType _onFail; - -}; - -class RPCFailHandlerNo : public RPCAbstractFailHandler { // fail() - using CallbackType = bool (*)(); - -public: - RPCFailHandlerNo(CallbackType onFail) : _onFail(onFail) { - } - bool operator()(mtpRequestId requestId, const RPCError &e) override { - return (*_onFail)(); - } - -private: - CallbackType _onFail; - -}; - -class RPCFailHandlerNoReq : public RPCAbstractFailHandler { // fail(req_id) - using CallbackType = bool (*)(mtpRequestId); - -public: - RPCFailHandlerNoReq(CallbackType onFail) : _onFail(onFail) { - } - bool operator()(mtpRequestId requestId, const RPCError &e) override { - return (*_onFail)(requestId); - } - -private: - CallbackType _onFail; - -}; - -struct RPCCallbackClear { - RPCCallbackClear(mtpRequestId id, int32 code = RPCError::NoError) - : requestId(id) - , errorCode(code) { - } - +struct Response { + mtpBuffer reply; + mtpMsgId outerMsgId = 0; mtpRequestId requestId = 0; - int32 errorCode = 0; - }; -inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *)) { // done(from, end) - return RPCDoneHandlerPtr(new RPCDoneHandlerBare(onDone)); -} - -inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *, mtpRequestId)) { // done(from, end, req_id) - return RPCDoneHandlerPtr(new RPCDoneHandlerBareReq(onDone)); -} - -template -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const TResponse &)) { // done(result) - return RPCDoneHandlerPtr(new RPCDoneHandlerPlain(onDone)); -} - -template -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const TResponse &, mtpRequestId)) { // done(result, req_id) - return RPCDoneHandlerPtr(new RPCDoneHandlerReq(onDone)); -} - -template -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)()) { // done() - return RPCDoneHandlerPtr(new RPCDoneHandlerNo(onDone)); -} - -template -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(mtpRequestId)) { // done(req_id) - return RPCDoneHandlerPtr(new RPCDoneHandlerNoReq(onDone)); -} - -inline RPCFailHandlerPtr rpcFail(bool (*onFail)(const RPCError &)) { // fail(error) - return RPCFailHandlerPtr(new RPCFailHandlerPlain(onFail)); -} - -inline RPCFailHandlerPtr rpcFail(bool (*onFail)(const RPCError &, mtpRequestId)) { // fail(error, req_id) - return RPCFailHandlerPtr(new RPCFailHandlerReq(onFail)); -} - -inline RPCFailHandlerPtr rpcFail(bool (*onFail)()) { // fail() - return RPCFailHandlerPtr(new RPCFailHandlerNo(onFail)); -} - -inline RPCFailHandlerPtr rpcFail(bool (*onFail)(mtpRequestId)) { // fail(req_id) - return RPCFailHandlerPtr(new RPCFailHandlerNoReq(onFail)); -} - -using MTPStateChangedHandler = void (*)(int32 dcId, int32 state); -using MTPSessionResetHandler = void (*)(int32 dcId); - -template -class RPCHandlerImplementation : public Base { -protected: - using Lambda = FnMut; - using Parent = RPCHandlerImplementation; - -public: - RPCHandlerImplementation(Lambda handler) : _handler(std::move(handler)) { - } - -protected: - Lambda _handler; +using DoneHandler = FnMut; +using FailHandler = Fn; +struct ResponseHandler { + DoneHandler done; + FailHandler fail; }; -template -using RPCDoneHandlerImplementation = RPCHandlerImplementation; - -class RPCDoneHandlerImplementationBare : public RPCDoneHandlerImplementation { // done(from, end) -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return this->_handler ? this->_handler(from, end) : true; - } - -}; - -class RPCDoneHandlerImplementationBareReq : public RPCDoneHandlerImplementation { // done(from, end, req_id) -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return this->_handler ? this->_handler(from, end, requestId) : true; - } - -}; - -template -class RPCDoneHandlerImplementationPlain : public RPCDoneHandlerImplementation { // done(result) -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - auto response = TResponse(); - if (!response.read(from, end)) { - return false; - } - if (this->_handler) { - this->_handler(std::move(response)); - } - return true; - } - -}; - -template -class RPCDoneHandlerImplementationReq : public RPCDoneHandlerImplementation { // done(result, req_id) -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - auto response = TResponse(); - if (!response.read(from, end)) { - return false; - } - if (this->_handler) { - this->_handler(std::move(response), requestId); - } - return true; - } - -}; - -template -class RPCDoneHandlerImplementationNo : public RPCDoneHandlerImplementation { // done() -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - if (this->_handler) { - this->_handler(); - } - return true; - } - -}; - -template -class RPCDoneHandlerImplementationNoReq : public RPCDoneHandlerImplementation { // done(req_id) -public: - using RPCDoneHandlerImplementation::Parent::Parent; - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - if (this->_handler) { - this->_handler(requestId); - } - return true; - } - -}; - -template -constexpr bool rpcDone_canCallBare_v = rpl::details::is_callable_plain_v< - Lambda, const mtpPrime*, const mtpPrime*>; - -template -constexpr bool rpcDone_canCallBareReq_v = rpl::details::is_callable_plain_v< - Lambda, const mtpPrime*, const mtpPrime*, mtpRequestId>; - -template -constexpr bool rpcDone_canCallNo_v = rpl::details::is_callable_plain_v< - Lambda>; - -template -constexpr bool rpcDone_canCallNoReq_v = rpl::details::is_callable_plain_v< - Lambda, mtpRequestId>; - -template -struct rpcDone_canCallPlain : std::false_type { -}; - -template -struct rpcDone_canCallPlain : std::true_type { - using Arg = T; -}; - -template -struct rpcDone_canCallPlain - : rpcDone_canCallPlain { -}; - -template -constexpr bool rpcDone_canCallPlain_v = rpcDone_canCallPlain::value; - -template -struct rpcDone_canCallReq : std::false_type { -}; - -template -struct rpcDone_canCallReq : std::true_type { - using Arg = T; -}; - -template -struct rpcDone_canCallReq - : rpcDone_canCallReq { -}; - -template -constexpr bool rpcDone_canCallReq_v = rpcDone_canCallReq::value; - -template -struct rpcDone_returnType; - -template -struct rpcDone_returnType { - using type = Return; -}; - -template -struct rpcDone_returnType { - using type = Return; -}; - -template -using rpcDone_returnType_t = typename rpcDone_returnType::type; - -template < - typename Lambda, - typename Function = crl::deduced_call_type> -RPCDoneHandlerPtr rpcDone(Lambda lambda) { - using R = rpcDone_returnType_t; - if constexpr (rpcDone_canCallBare_v) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBare(std::move(lambda))); - } else if constexpr (rpcDone_canCallBareReq_v) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBareReq(std::move(lambda))); - } else if constexpr (rpcDone_canCallNo_v) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationNo(std::move(lambda))); - } else if constexpr (rpcDone_canCallNoReq_v) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationNoReq(std::move(lambda))); - } else if constexpr (rpcDone_canCallPlain_v) { - using T = typename rpcDone_canCallPlain::Arg; - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationPlain(std::move(lambda))); - } else if constexpr (rpcDone_canCallReq_v) { - using T = typename rpcDone_canCallReq::Arg; - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationReq(std::move(lambda))); - } else { - static_assert(false_t(lambda), "Unknown method."); - } -} - -template -using RPCFailHandlerImplementation = RPCHandlerImplementation; - -class RPCFailHandlerImplementationPlain : public RPCFailHandlerImplementation { // fail(error) -public: - using Parent::Parent; - bool operator()(mtpRequestId requestId, const RPCError &error) override { - return _handler ? _handler(error) : true; - } - -}; - -class RPCFailHandlerImplementationReq : public RPCFailHandlerImplementation { // fail(error, req_id) -public: - using Parent::Parent; - bool operator()(mtpRequestId requestId, const RPCError &error) override { - return this->_handler ? this->_handler(error, requestId) : true; - } - -}; - -class RPCFailHandlerImplementationNo : public RPCFailHandlerImplementation { // fail() -public: - using Parent::Parent; - bool operator()(mtpRequestId requestId, const RPCError &error) override { - return this->_handler ? this->_handler() : true; - } - -}; - -class RPCFailHandlerImplementationNoReq : public RPCFailHandlerImplementation { // fail(req_id) -public: - using Parent::Parent; - bool operator()(mtpRequestId requestId, const RPCError &error) override { - return this->_handler ? this->_handler(requestId) : true; - } - -}; - -template -constexpr bool rpcFail_canCallNo_v = rpl::details::is_callable_plain_v< - Lambda>; - -template -constexpr bool rpcFail_canCallNoReq_v = rpl::details::is_callable_plain_v< - Lambda, mtpRequestId>; - -template -constexpr bool rpcFail_canCallPlain_v = rpl::details::is_callable_plain_v< - Lambda, const RPCError&>; - -template -constexpr bool rpcFail_canCallReq_v = rpl::details::is_callable_plain_v< - Lambda, const RPCError&, mtpRequestId>; - -template < - typename Lambda, - typename Function = crl::deduced_call_type> -RPCFailHandlerPtr rpcFail(Lambda lambda) { - if constexpr (rpcFail_canCallNo_v) { - return RPCFailHandlerPtr(new RPCFailHandlerImplementationNo(std::move(lambda))); - } else if constexpr (rpcFail_canCallNoReq_v) { - return RPCFailHandlerPtr(new RPCFailHandlerImplementationNoReq(std::move(lambda))); - } else if constexpr (rpcFail_canCallPlain_v) { - return RPCFailHandlerPtr(new RPCFailHandlerImplementationPlain(std::move(lambda))); - } else if constexpr (rpcFail_canCallReq_v) { - return RPCFailHandlerPtr(new RPCFailHandlerImplementationReq(std::move(lambda))); - } else { - static_assert(false_t(lambda), "Unknown method."); - } -} +} // namespace MTP diff --git a/Telegram/SourceFiles/mtproto/sender.h b/Telegram/SourceFiles/mtproto/sender.h index 387a1e983..2af657032 100644 --- a/Telegram/SourceFiles/mtproto/sender.h +++ b/Telegram/SourceFiles/mtproto/sender.h @@ -22,111 +22,108 @@ class Sender { RequestBuilder &operator=(RequestBuilder &&other) = delete; protected: - using FailPlainHandler = FnMut; - using FailRequestIdHandler = FnMut; enum class FailSkipPolicy { Simple, HandleFlood, HandleAll, }; - template - struct DonePlainPolicy { - using Callback = FnMut; - static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) { - handler(result); - } + using FailPlainHandler = Fn; + using FailErrorHandler = Fn; + using FailRequestIdHandler = Fn; + using FailFullHandler = Fn; - }; - template - struct DoneRequestIdPolicy { - using Callback = FnMut; - static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) { - handler(result, requestId); - } + template + static constexpr bool IsCallable + = rpl::details::is_callable_plain_v; - }; - template typename PolicyTemplate> - class DoneHandler : public RPCAbstractDoneHandler { - using Policy = PolicyTemplate; - using Callback = typename Policy::Callback; + template + [[nodiscard]] DoneHandler MakeDoneHandler( + not_null sender, + Handler &&handler) { + return [sender, handler = std::forward(handler)]( + const Response &response) mutable { + auto onstack = std::move(handler); + sender->senderRequestHandled(response.requestId); - public: - DoneHandler(not_null sender, Callback handler) : _sender(sender), _handler(std::move(handler)) { - } - - bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - auto handler = std::move(_handler); - _sender->senderRequestHandled(requestId); - - auto result = Response(); - if (!result.read(from, end)) { + auto result = Result(); + auto from = response.reply.constData(); + if (!result.read(from, from + response.reply.size())) { return false; - } - if (handler) { - Policy::handle(std::move(handler), requestId, std::move(result)); + } else if (!onstack) { + return true; + } else if constexpr (IsCallable< + Handler, + const Result&, + const Response&>) { + onstack(result, response); + } else if constexpr (IsCallable< + Handler, + const Result&, + mtpRequestId>) { + onstack(result, response.requestId); + } else if constexpr (IsCallable< + Handler, + const Result&>) { + onstack(result); + } else if constexpr (IsCallable) { + onstack(); + } else { + static_assert(false_t(Handler{}), "Bad done handler."); } return true; - } + }; + } - private: - not_null _sender; - Callback _handler; - - }; - - struct FailPlainPolicy { - using Callback = FnMut; - static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) { - handler(error); - } - - }; - struct FailRequestIdPolicy { - using Callback = FnMut; - static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) { - handler(error, requestId); - } - - }; - template - class FailHandler : public RPCAbstractFailHandler { - using Callback = typename Policy::Callback; - - public: - FailHandler(not_null sender, Callback handler, FailSkipPolicy skipPolicy) - : _sender(sender) - , _handler(std::move(handler)) - , _skipPolicy(skipPolicy) { - } - - bool operator()(mtpRequestId requestId, const RPCError &error) override { - if (_skipPolicy == FailSkipPolicy::Simple) { + template + [[nodiscard]] FailHandler MakeFailHandler( + not_null sender, + Handler &&handler, + FailSkipPolicy skipPolicy) { + return [ + sender, + handler = std::forward(handler), + skipPolicy + ](const RPCError &error, const Response &response) { + if (skipPolicy == FailSkipPolicy::Simple) { if (isDefaultHandledError(error)) { return false; } - } else if (_skipPolicy == FailSkipPolicy::HandleFlood) { + } else if (skipPolicy == FailSkipPolicy::HandleFlood) { if (isDefaultHandledError(error) && !isFloodError(error)) { return false; } } - auto handler = std::move(_handler); - _sender->senderRequestHandled(requestId); + auto onstack = handler; + sender->senderRequestHandled(response.requestId); - if (handler) { - Policy::handle(std::move(handler), requestId, error); + if (!onstack) { + return true; + } else if constexpr (IsCallable< + Handler, + const RPCError&, + const Response&>) { + onstack(error, response); + } else if constexpr (IsCallable< + Handler, + const RPCError&, + mtpRequestId>) { + onstack(error, response.requestId); + } else if constexpr (IsCallable< + Handler, + const RPCError&>) { + onstack(error); + } else if constexpr (IsCallable) { + onstack(); + } else { + static_assert(false_t(Handler{}), "Bad fail handler."); } return true; - } + }; + } - private: - not_null _sender; - Callback _handler; - FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple; - - }; - - explicit RequestBuilder(not_null sender) noexcept : _sender(sender) { + explicit RequestBuilder(not_null sender) noexcept + : _sender(sender) { } RequestBuilder(RequestBuilder &&other) = default; @@ -136,14 +133,12 @@ class Sender { void setCanWait(crl::time ms) noexcept { _canWait = ms; } - void setDoneHandler(RPCDoneHandlerPtr &&handler) noexcept { + void setDoneHandler(DoneHandler &&handler) noexcept { _done = std::move(handler); } - void setFailHandler(FailPlainHandler &&handler) noexcept { - _fail = std::move(handler); - } - void setFailHandler(FailRequestIdHandler &&handler) noexcept { - _fail = std::move(handler); + template + void setFailHandler(Handler &&handler) noexcept { + _fail = std::forward(handler); } void setFailSkipPolicy(FailSkipPolicy policy) noexcept { _failSkipPolicy = policy; @@ -158,18 +153,12 @@ class Sender { crl::time takeCanWait() const noexcept { return _canWait; } - RPCDoneHandlerPtr takeOnDone() noexcept { + DoneHandler takeOnDone() noexcept { return std::move(_done); } - RPCFailHandlerPtr takeOnFail() { - return v::match(_fail, [&](FailPlainHandler &value) - -> RPCFailHandlerPtr { - return std::make_shared>( - _sender, - std::move(value), - _failSkipPolicy); - }, [&](FailRequestIdHandler &value) -> RPCFailHandlerPtr { - return std::make_shared>( + FailHandler takeOnFail() { + return v::match(_fail, [&](auto &value) { + return MakeFailHandler( _sender, std::move(value), _failSkipPolicy); @@ -190,8 +179,12 @@ class Sender { not_null _sender; ShiftedDcId _dcId = 0; crl::time _canWait = 0; - RPCDoneHandlerPtr _done; - std::variant _fail; + DoneHandler _done; + std::variant< + FailPlainHandler, + FailErrorHandler, + FailRequestIdHandler, + FailFullHandler> _fail; FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple; mtpRequestId _afterRequestId = 0; @@ -210,7 +203,9 @@ public: class SpecificRequestBuilder : public RequestBuilder { private: friend class Sender; - SpecificRequestBuilder(not_null sender, Request &&request) noexcept : RequestBuilder(sender), _request(std::move(request)) { + SpecificRequestBuilder(not_null sender, Request &&request) noexcept + : RequestBuilder(sender) + , _request(std::move(request)) { } SpecificRequestBuilder(SpecificRequestBuilder &&other) = default; @@ -223,22 +218,63 @@ public: setCanWait(ms); return *this; } - [[nodiscard]] SpecificRequestBuilder &done(FnMut callback) { - setDoneHandler(std::make_shared>(sender(), std::move(callback))); + + using Result = typename Request::ResponseType; + [[nodiscard]] SpecificRequestBuilder &done( + FnMut callback) { + setDoneHandler( + MakeDoneHandler(sender(), std::move(callback))); return *this; } - [[nodiscard]] SpecificRequestBuilder &done(FnMut callback) { - setDoneHandler(std::make_shared>(sender(), std::move(callback))); + [[nodiscard]] SpecificRequestBuilder &done( + FnMut callback) { + setDoneHandler( + MakeDoneHandler(sender(), std::move(callback))); return *this; } - [[nodiscard]] SpecificRequestBuilder &fail(FnMut callback) noexcept { + [[nodiscard]] SpecificRequestBuilder &done( + FnMut callback) { + setDoneHandler( + MakeDoneHandler(sender(), std::move(callback))); + return *this; + } + [[nodiscard]] SpecificRequestBuilder &done( + FnMut callback) { + setDoneHandler( + MakeDoneHandler(sender(), std::move(callback))); + return *this; + } + + [[nodiscard]] SpecificRequestBuilder &fail( + Fn callback) noexcept { setFailHandler(std::move(callback)); return *this; } - [[nodiscard]] SpecificRequestBuilder &fail(FnMut callback) noexcept { + [[nodiscard]] SpecificRequestBuilder &fail( + Fn callback) noexcept { setFailHandler(std::move(callback)); return *this; } + [[nodiscard]] SpecificRequestBuilder &fail( + Fn callback) noexcept { + setFailHandler(std::move(callback)); + return *this; + } + [[nodiscard]] SpecificRequestBuilder &fail( + Fn callback) noexcept { + setFailHandler(std::move(callback)); + return *this; + } + [[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept { setFailSkipPolicy(FailSkipPolicy::HandleFlood); return *this; diff --git a/Telegram/SourceFiles/mtproto/session.cpp b/Telegram/SourceFiles/mtproto/session.cpp index b4f264662..c0c3cda91 100644 --- a/Telegram/SourceFiles/mtproto/session.cpp +++ b/Telegram/SourceFiles/mtproto/session.cpp @@ -225,13 +225,6 @@ void Session::start() { _shiftedDcId); } -bool Session::rpcErrorOccured( - mtpRequestId requestId, - const RPCFailHandlerPtr &onFail, - const RPCError &error) { // return true if need to clean request data - return _instance->rpcErrorOccured(requestId, onFail, error); -} - void Session::restart() { if (_killed) { DEBUG_LOG(("Session Error: can't restart a killed session")); @@ -560,25 +553,17 @@ void Session::tryToReceive() { } while (true) { auto lock = QWriteLocker(_data->haveReceivedMutex()); - const auto responses = base::take(_data->haveReceivedResponses()); - const auto updates = base::take(_data->haveReceivedUpdates()); + const auto messages = base::take(_data->haveReceivedMessages()); lock.unlock(); - if (responses.empty() && updates.empty()) { + if (messages.empty()) { break; } - for (const auto &[requestId, response] : responses) { - _instance->execCallback( - requestId, - response.constData(), - response.constData() + response.size()); - } - - // Call globalCallback only in main session. - if (_shiftedDcId == BareDcId(_shiftedDcId)) { - for (const auto &update : updates) { - _instance->globalCallback( - update.constData(), - update.constData() + update.size()); + for (const auto &message : messages) { + if (message.requestId) { + _instance->processCallback(message); + } else if (_shiftedDcId == BareDcId(_shiftedDcId)) { + // Process updates only in main session. + _instance->processUpdate(message); } } } diff --git a/Telegram/SourceFiles/mtproto/session.h b/Telegram/SourceFiles/mtproto/session.h index dc8b152c6..95ac0da70 100644 --- a/Telegram/SourceFiles/mtproto/session.h +++ b/Telegram/SourceFiles/mtproto/session.h @@ -84,11 +84,8 @@ public: base::flat_map &haveSentMap() { return _haveSent; } - base::flat_map &haveReceivedResponses() { - return _receivedResponses; - } - std::vector &haveReceivedUpdates() { - return _receivedUpdates; + std::vector &haveReceivedMessages() { + return _receivedMessages; } // SessionPrivate -> Session interface. @@ -128,8 +125,7 @@ private: base::flat_map _haveSent; // map of msg_id -> request, that was sent QReadWriteLock _haveSentLock; - base::flat_map _receivedResponses; // map of request_id -> response that should be processed in the main thread - std::vector _receivedUpdates; // list of updates that should be processed in the main thread + std::vector _receivedMessages; // list of responses / updates that should be processed in the main thread QReadWriteLock _haveReceivedLock; }; @@ -192,11 +188,6 @@ private: void killConnection(); - bool rpcErrorOccured( - mtpRequestId requestId, - const RPCFailHandlerPtr &onFail, - const RPCError &err); - [[nodiscard]] bool releaseGenericKeyCreationOnDone( const AuthKeyPtr &temporaryKey, const AuthKeyPtr &persistentKeyUsedForBind); diff --git a/Telegram/SourceFiles/mtproto/session_private.cpp b/Telegram/SourceFiles/mtproto/session_private.cpp index ea5b3b30f..ce37352c1 100644 --- a/Telegram/SourceFiles/mtproto/session_private.cpp +++ b/Telegram/SourceFiles/mtproto/session_private.cpp @@ -1363,7 +1363,12 @@ void SessionPrivate::handleReceived() { ).arg(_encryptionKey->keyId())); if (_receivedMessageIds.registerMsgId(msgId, needAck)) { - res = handleOneReceived(from, end, msgId, serverTime, serverSalt, badTime); + res = handleOneReceived(from, end, msgId, { + .outerMsgId = msgId, + .serverSalt = serverSalt, + .serverTime = serverTime, + .badTime = badTime, + }); } _receivedMessageIds.shrink(); @@ -1374,12 +1379,11 @@ void SessionPrivate::handleReceived() { } auto lock = QReadLocker(_sessionData->haveReceivedMutex()); - const auto tryToReceive = !_sessionData->haveReceivedResponses().empty() - || !_sessionData->haveReceivedUpdates().empty(); + const auto tryToReceive = !_sessionData->haveReceivedMessages().empty(); lock.unlock(); if (tryToReceive) { - DEBUG_LOG(("MTP Info: queueTryToReceive() - need to parse in another thread, %1 responses, %2 updates.").arg(_sessionData->haveReceivedResponses().size()).arg(_sessionData->haveReceivedUpdates().size())); + DEBUG_LOG(("MTP Info: queueTryToReceive() - need to parse in another thread, %1 messages.").arg(_sessionData->haveReceivedMessages().size())); _sessionData->queueTryToReceive(); } @@ -1410,9 +1414,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( const mtpPrime *from, const mtpPrime *end, uint64 msgId, - int32 serverTime, - uint64 serverSalt, - bool badTime) { + OuterInfo info) { Expects(from < end); switch (mtpTypeId(*from)) { @@ -1423,7 +1425,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( if (response.empty()) { return HandleResult::RestartConnection; } - return handleOneReceived(response.data(), response.data() + response.size(), msgId, serverTime, serverSalt, badTime); + return handleOneReceived(response.data(), response.data() + response.size(), msgId, info); } case mtpc_msg_container: { @@ -1475,8 +1477,8 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( auto res = HandleResult::Success; // if no need to handle, then succeed if (_receivedMessageIds.registerMsgId(inMsgId.v, needAck)) { - res = handleOneReceived(from, otherEnd, inMsgId.v, serverTime, serverSalt, badTime); - badTime = false; + res = handleOneReceived(from, otherEnd, inMsgId.v, info); + info.badTime = false; } if (res != HandleResult::Success) { return res; @@ -1495,15 +1497,15 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( DEBUG_LOG(("Message Info: acks received, ids: %1" ).arg(LogIdsVector(ids))); if (ids.isEmpty()) { - return badTime ? HandleResult::Ignored : HandleResult::Success; + return info.badTime ? HandleResult::Ignored : HandleResult::Success; } - if (badTime) { - if (!requestsFixTimeSalt(ids, serverTime, serverSalt)) { + if (info.badTime) { + if (!requestsFixTimeSalt(ids, info)) { return HandleResult::Ignored; } } else { - correctUnixtimeByFastRequest(ids, serverTime); + correctUnixtimeByFastRequest(ids, info.serverTime); } requestsAcked(ids); } return HandleResult::Success; @@ -1546,28 +1548,28 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( if (!wasSent(resendId)) { DEBUG_LOG(("Message Error: " "such message was not sent recently %1").arg(resendId)); - return badTime + return info.badTime ? HandleResult::Ignored : HandleResult::Success; } if (needResend) { // bad msg_id or bad container - if (serverSalt) { - _sessionSalt = serverSalt; + if (info.serverSalt) { + _sessionSalt = info.serverSalt; } - correctUnixtimeWithBadLocal(serverTime); + correctUnixtimeWithBadLocal(info.serverTime); - DEBUG_LOG(("Message Info: unixtime updated, now %1, resending in container...").arg(serverTime)); + DEBUG_LOG(("Message Info: unixtime updated, now %1, resending in container...").arg(info.serverTime)); resend(resendId, 0, true); } else { // must create new session, because msg_id and msg_seqno are inconsistent - if (badTime) { - if (serverSalt) { - _sessionSalt = serverSalt; + if (info.badTime) { + if (info.serverSalt) { + _sessionSalt = info.serverSalt; } - correctUnixtimeWithBadLocal(serverTime); - badTime = false; + correctUnixtimeWithBadLocal(info.serverTime); + info.badTime = false; } LOG(("Message Info: bad message notification received, msgId %1, error_code %2").arg(data.vbad_msg_id().v).arg(errorCode)); return HandleResult::ResetSession; @@ -1582,20 +1584,24 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( ).arg(badMsgId ).arg(errorCode ).arg(requestId)); - auto response = mtpBuffer(); + auto reply = mtpBuffer(); MTPRpcError(MTP_rpc_error( MTP_int(500), MTP_string("PROTOCOL_ERROR") - )).write(response); + )).write(reply); // Save rpc_error for processing in the main thread. QWriteLocker locker(_sessionData->haveReceivedMutex()); - _sessionData->haveReceivedResponses().emplace(requestId, response); + _sessionData->haveReceivedMessages().push_back({ + .reply = std::move(reply), + .outerMsgId = info.outerMsgId, + .requestId = requestId, + }); } else { DEBUG_LOG(("Message Error: " "such message was not sent recently %1").arg(badMsgId)); } - return badTime + return info.badTime ? HandleResult::Ignored : HandleResult::Success; } @@ -1612,19 +1618,19 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( const auto resendId = data.vbad_msg_id().v; if (!wasSent(resendId)) { DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(resendId)); - return (badTime ? HandleResult::Ignored : HandleResult::Success); + return (info.badTime ? HandleResult::Ignored : HandleResult::Success); } _sessionSalt = data.vnew_server_salt().v; - correctUnixtimeWithBadLocal(serverTime); + correctUnixtimeWithBadLocal(info.serverTime); if (setState(ConnectedState, ConnectingState)) { resendAll(); } - badTime = false; + info.badTime = false; - DEBUG_LOG(("Message Info: unixtime updated, now %1, server_salt updated, now %2, resending...").arg(serverTime).arg(serverSalt)); + DEBUG_LOG(("Message Info: unixtime updated, now %1, server_salt updated, now %2, resending...").arg(info.serverTime).arg(info.serverSalt)); resend(resendId); } return HandleResult::Success; @@ -1642,17 +1648,19 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( const auto i = _stateAndResendRequests.find(reqMsgId); if (i == _stateAndResendRequests.end()) { DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId)); - return (badTime ? HandleResult::Ignored : HandleResult::Success); + return info.badTime + ? HandleResult::Ignored + : HandleResult::Success; } - if (badTime) { - if (serverSalt) { - _sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup + if (info.badTime) { + if (info.serverSalt) { + _sessionSalt = info.serverSalt; // requestsFixTimeSalt with no lookup } - correctUnixtimeWithBadLocal(serverTime); + correctUnixtimeWithBadLocal(info.serverTime); - DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime)); + DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(info.serverTime)); - badTime = false; + info.badTime = false; } const auto originalRequest = i->second; Assert(originalRequest->size() > 8); @@ -1680,7 +1688,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( } return HandleResult::Success; case mtpc_msgs_all_info: { - if (badTime) { + if (info.badTime) { DEBUG_LOG(("Message Info: skipping with bad time...")); return HandleResult::Ignored; } @@ -1707,9 +1715,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( DEBUG_LOG(("Message Info: msg detailed info, sent msgId %1, answerId %2, status %3, bytes %4").arg(data.vmsg_id().v).arg(data.vanswer_msg_id().v).arg(data.vstatus().v).arg(data.vbytes().v)); QVector ids(1, data.vmsg_id()); - if (badTime) { - if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { - badTime = false; + if (info.badTime) { + if (requestsFixTimeSalt(ids, info)) { + info.badTime = false; } else { DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vmsg_id().v)); return HandleResult::Ignored; @@ -1727,7 +1735,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( } return HandleResult::Success; case mtpc_msg_new_detailed_info: { - if (badTime) { + if (info.badTime) { DEBUG_LOG(("Message Info: skipping msg_new_detailed_info with bad time...")); return HandleResult::Ignored; } @@ -1763,9 +1771,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( DEBUG_LOG(("RPC Info: response received for %1, queueing...").arg(requestMsgId)); QVector ids(1, reqMsgId); - if (badTime) { - if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { - badTime = false; + if (info.badTime) { + if (requestsFixTimeSalt(ids, info)) { + info.badTime = false; } else { DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(requestMsgId)); return HandleResult::Ignored; @@ -1804,7 +1812,11 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( if (requestId && requestId != mtpRequestId(0xFFFFFFFF)) { // Save rpc_result for processing in the main thread. QWriteLocker locker(_sessionData->haveReceivedMutex()); - _sessionData->haveReceivedResponses().emplace(requestId, response); + _sessionData->haveReceivedMessages().push_back({ + .reply = std::move(response), + .outerMsgId = info.outerMsgId, + .requestId = requestId, + }); } else { DEBUG_LOG(("RPC Info: requestId not found for msgId %1").arg(requestMsgId)); } @@ -1818,9 +1830,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( } const auto &data(msg.c_new_session_created()); - if (badTime) { - if (requestsFixTimeSalt(QVector(1, data.vfirst_msg_id()), serverTime, serverSalt)) { - badTime = false; + if (info.badTime) { + if (requestsFixTimeSalt(QVector(1, data.vfirst_msg_id()), info)) { + info.badTime = false; } else { DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vfirst_msg_id().v)); return HandleResult::Ignored; @@ -1853,7 +1865,10 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( // Notify main process about new session - need to get difference. QWriteLocker locker(_sessionData->haveReceivedMutex()); - _sessionData->haveReceivedUpdates().push_back(mtpBuffer(update)); + _sessionData->haveReceivedMessages().push_back({ + .reply = update, + .outerMsgId = info.outerMsgId, + }); } return HandleResult::Success; case mtpc_pong: { @@ -1875,9 +1890,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( } QVector ids(1, data.vmsg_id()); - if (badTime) { - if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { - badTime = false; + if (info.badTime) { + if (requestsFixTimeSalt(ids, info)) { + info.badTime = false; } else { return HandleResult::Ignored; } @@ -1887,7 +1902,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( } - if (badTime) { + if (info.badTime) { DEBUG_LOG(("Message Error: bad time in updates cons, must create new session")); return HandleResult::ResetSession; } @@ -1900,7 +1915,10 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived( // Notify main process about the new updates. QWriteLocker locker(_sessionData->haveReceivedMutex()); - _sessionData->haveReceivedUpdates().push_back(mtpBuffer(update)); + _sessionData->haveReceivedMessages().push_back({ + .reply = update, + .outerMsgId = info.outerMsgId, + }); } else { LOG(("Message Error: unexpected updates in dcType: %1" ).arg(static_cast(_currentDcType))); @@ -1991,14 +2009,14 @@ mtpBuffer SessionPrivate::ungzip(const mtpPrime *from, const mtpPrime *end) cons return result; } -bool SessionPrivate::requestsFixTimeSalt(const QVector &ids, int32 serverTime, uint64 serverSalt) { +bool SessionPrivate::requestsFixTimeSalt(const QVector &ids, const OuterInfo &info) { for (const auto &id : ids) { if (wasSent(id.v)) { // Found such msg_id in recent acked or in recent sent requests. - if (serverSalt) { - _sessionSalt = serverSalt; + if (info.serverSalt) { + _sessionSalt = info.serverSalt; } - correctUnixtimeWithBadLocal(serverTime); + correctUnixtimeWithBadLocal(info.serverTime); return true; } } @@ -2063,7 +2081,7 @@ void SessionPrivate::requestsAcked(const QVector &ids, bool byResponse) if (const auto i = haveSent.find(msgId); i != end(haveSent)) { const auto requestId = i->second->requestId; - if (!byResponse && _instance->hasCallbacks(requestId)) { + if (!byResponse && _instance->hasCallback(requestId)) { DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId)); continue; } @@ -2076,7 +2094,7 @@ void SessionPrivate::requestsAcked(const QVector &ids, bool byResponse) if (const auto i = _resendingIds.find(msgId); i != end(_resendingIds)) { const auto requestId = i->second; - if (!byResponse && _instance->hasCallbacks(requestId)) { + if (!byResponse && _instance->hasCallback(requestId)) { DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId)); continue; } diff --git a/Telegram/SourceFiles/mtproto/session_private.h b/Telegram/SourceFiles/mtproto/session_private.h index cf7c68ab4..2c75c4511 100644 --- a/Telegram/SourceFiles/mtproto/session_private.h +++ b/Telegram/SourceFiles/mtproto/session_private.h @@ -120,7 +120,17 @@ private: bool needAnyResponse); mtpRequestId wasSent(mtpMsgId msgId) const; - [[nodiscard]] HandleResult handleOneReceived(const mtpPrime *from, const mtpPrime *end, uint64 msgId, int32 serverTime, uint64 serverSalt, bool badTime); + struct OuterInfo { + mtpMsgId outerMsgId = 0; + uint64 serverSalt = 0; + int32 serverTime = 0; + bool badTime = false; + }; + [[nodiscard]] HandleResult handleOneReceived( + const mtpPrime *from, + const mtpPrime *end, + uint64 msgId, + OuterInfo info); [[nodiscard]] HandleResult handleBindResponse( mtpMsgId requestMsgId, const mtpBuffer &response); @@ -137,7 +147,7 @@ private: const bytes::vector &protocolSecret); // if badTime received - search for ids in sessionData->haveSent and sessionData->wereAcked and sync time/salt, return true if found - bool requestsFixTimeSalt(const QVector &ids, int32 serverTime, uint64 serverSalt); + bool requestsFixTimeSalt(const QVector &ids, const OuterInfo &info); // if we had a confirmed fast request use its unixtime as a correct one. void correctUnixtimeByFastRequest(