Allow getting outer_msg_id in MTProto handlers.

This commit is contained in:
John Preston 2021-03-12 14:45:13 +04:00
parent e681b0d95a
commit 267e5fd9e0
22 changed files with 756 additions and 1139 deletions

View file

@ -157,7 +157,7 @@ void EditMessageWithUploadedMedia(
void RescheduleMessage( void RescheduleMessage(
not_null<HistoryItem*> item, not_null<HistoryItem*> item,
SendOptions options) { SendOptions options) {
const auto empty = [](const auto &r) {}; const auto empty = [] {};
EditMessage(item, options, empty, empty); EditMessage(item, options, empty, empty);
} }

View file

@ -390,7 +390,7 @@ void ApiWrap::acceptTerms(bytes::const_span id) {
void ApiWrap::checkChatInvite( void ApiWrap::checkChatInvite(
const QString &hash, const QString &hash,
FnMut<void(const MTPChatInvite &)> done, FnMut<void(const MTPChatInvite &)> done,
FnMut<void(const RPCError &)> fail) { Fn<void(const RPCError &)> fail) {
request(base::take(_checkInviteRequestId)).cancel(); request(base::take(_checkInviteRequestId)).cancel();
_checkInviteRequestId = request(MTPmessages_CheckChatInvite( _checkInviteRequestId = request(MTPmessages_CheckChatInvite(
MTP_string(hash) MTP_string(hash)
@ -1227,7 +1227,7 @@ void ApiWrap::requestPeerSettings(not_null<PeerData*> peer) {
void ApiWrap::migrateChat( void ApiWrap::migrateChat(
not_null<ChatData*> chat, not_null<ChatData*> chat,
FnMut<void(not_null<ChannelData*>)> done, FnMut<void(not_null<ChannelData*>)> done,
FnMut<void(const RPCError &)> fail) { Fn<void(const RPCError &)> fail) {
const auto callback = [&] { const auto callback = [&] {
return MigrateCallbacks{ std::move(done), std::move(fail) }; return MigrateCallbacks{ std::move(done), std::move(fail) };
}; };

View file

@ -227,7 +227,7 @@ public:
void checkChatInvite( void checkChatInvite(
const QString &hash, const QString &hash,
FnMut<void(const MTPChatInvite &)> done, FnMut<void(const MTPChatInvite &)> done,
FnMut<void(const RPCError &)> fail); Fn<void(const RPCError &)> fail);
void importChatInvite(const QString &hash); void importChatInvite(const QString &hash);
void requestChannelMembersForAdd( void requestChannelMembersForAdd(
@ -243,7 +243,7 @@ public:
void migrateChat( void migrateChat(
not_null<ChatData*> chat, not_null<ChatData*> chat,
FnMut<void(not_null<ChannelData*>)> done, FnMut<void(not_null<ChannelData*>)> done,
FnMut<void(const RPCError &)> fail = nullptr); Fn<void(const RPCError &)> fail = nullptr);
void markMediaRead(const base::flat_set<not_null<HistoryItem*>> &items); void markMediaRead(const base::flat_set<not_null<HistoryItem*>> &items);
void markMediaRead(not_null<HistoryItem*> item); void markMediaRead(not_null<HistoryItem*> item);
@ -742,11 +742,11 @@ private:
mtpRequestId _checkInviteRequestId = 0; mtpRequestId _checkInviteRequestId = 0;
FnMut<void(const MTPChatInvite &result)> _checkInviteDone; FnMut<void(const MTPChatInvite &result)> _checkInviteDone;
FnMut<void(const RPCError &error)> _checkInviteFail; Fn<void(const RPCError &error)> _checkInviteFail;
struct MigrateCallbacks { struct MigrateCallbacks {
FnMut<void(not_null<ChannelData*>)> done; FnMut<void(not_null<ChannelData*>)> done;
FnMut<void(const RPCError&)> fail; Fn<void(const RPCError&)> fail;
}; };
base::flat_map< base::flat_map<
not_null<PeerData*>, not_null<PeerData*>,

View file

@ -249,26 +249,26 @@ public:
RequestBuilder( RequestBuilder(
Original &&builder, Original &&builder,
FnMut<void(RPCError&&)> commonFailHandler); Fn<void(const RPCError&)> commonFailHandler);
[[nodiscard]] RequestBuilder &done(FnMut<void()> &&handler); [[nodiscard]] RequestBuilder &done(FnMut<void()> &&handler);
[[nodiscard]] RequestBuilder &done( [[nodiscard]] RequestBuilder &done(
FnMut<void(Response &&)> &&handler); FnMut<void(Response &&)> &&handler);
[[nodiscard]] RequestBuilder &fail( [[nodiscard]] RequestBuilder &fail(
FnMut<bool(const RPCError &)> &&handler); Fn<bool(const RPCError&)> &&handler);
mtpRequestId send(); mtpRequestId send();
private: private:
Original _builder; Original _builder;
FnMut<void(RPCError&&)> _commonFailHandler; Fn<void(const RPCError&)> _commonFailHandler;
}; };
template <typename Request> template <typename Request>
ApiWrap::RequestBuilder<Request>::RequestBuilder( ApiWrap::RequestBuilder<Request>::RequestBuilder(
Original &&builder, Original &&builder,
FnMut<void(RPCError&&)> commonFailHandler) Fn<void(const RPCError&)> commonFailHandler)
: _builder(std::move(builder)) : _builder(std::move(builder))
, _commonFailHandler(std::move(commonFailHandler)) { , _commonFailHandler(std::move(commonFailHandler)) {
} }
@ -295,15 +295,15 @@ auto ApiWrap::RequestBuilder<Request>::done(
template <typename Request> template <typename Request>
auto ApiWrap::RequestBuilder<Request>::fail( auto ApiWrap::RequestBuilder<Request>::fail(
FnMut<bool(const RPCError &)> &&handler Fn<bool(const RPCError &)> &&handler
) -> RequestBuilder& { ) -> RequestBuilder& {
if (handler) { if (handler) {
auto &silence_warning = _builder.fail([ auto &silence_warning = _builder.fail([
common = base::take(_commonFailHandler), common = base::take(_commonFailHandler),
specific = std::move(handler) specific = std::move(handler)
](RPCError &&error) mutable { ](const RPCError &error) mutable {
if (!specific(error)) { if (!specific(error)) {
common(std::move(error)); common(error);
} }
}); });
} }
@ -364,7 +364,7 @@ auto ApiWrap::mainRequest(Request &&request) {
return RequestBuilder<MTPInvokeWithTakeout<Request>>( return RequestBuilder<MTPInvokeWithTakeout<Request>>(
std::move(original), std::move(original),
[=](RPCError &&result) { error(std::move(result)); }); [=](const RPCError &result) { error(result); });
} }
template <typename Request> template <typename Request>
@ -391,7 +391,7 @@ auto ApiWrap::fileRequest(const Data::FileLocation &location, int offset) {
location.data, location.data,
MTP_int(offset), MTP_int(offset),
MTP_int(kFileChunkSize)) MTP_int(kFileChunkSize))
)).fail([=](RPCError &&result) { )).fail([=](const RPCError &result) {
if (result.type() == qstr("TAKEOUT_FILE_EMPTY") if (result.type() == qstr("TAKEOUT_FILE_EMPTY")
&& _otherDataProcess != nullptr) { && _otherDataProcess != nullptr) {
filePartDone( filePartDone(
@ -688,11 +688,11 @@ void ApiWrap::startMainSession(FnMut<void()> done) {
return data.vid().v; return data.vid().v;
}); });
done(); done();
}).fail([=](RPCError &&result) { }).fail([=](const RPCError &result) {
error(std::move(result)); error(result);
}).toDC(MTP::ShiftDcId(0, MTP::kExportDcShift)).send(); }).toDC(MTP::ShiftDcId(0, MTP::kExportDcShift)).send();
}).fail([=](RPCError &&result) { }).fail([=](const RPCError &result) {
error(std::move(result)); error(result);
}).send(); }).send();
} }
@ -1945,12 +1945,13 @@ void ApiWrap::filePartUnavailable() {
base::take(_fileProcess)->done(QString()); base::take(_fileProcess)->done(QString());
} }
void ApiWrap::error(RPCError &&error) { void ApiWrap::error(const RPCError &error) {
_errors.fire(std::move(error)); _errors.fire_copy(error);
} }
void ApiWrap::error(const QString &text) { void ApiWrap::error(const QString &text) {
error(MTP_rpc_error(MTP_int(0), MTP_string("API_ERROR: " + text))); error(RPCError(
MTP_rpc_error(MTP_int(0), MTP_string("API_ERROR: " + text))));
} }
void ApiWrap::ioError(const Output::Result &result) { void ApiWrap::ioError(const Output::Result &result) {

View file

@ -203,7 +203,7 @@ private:
const Data::FileLocation &location, const Data::FileLocation &location,
int offset); int offset);
void error(RPCError &&error); void error(const RPCError &error);
void error(const QString &text); void error(const QString &text);
void ioError(const Output::Result &result); void ioError(const Output::Result &result);

View file

@ -138,8 +138,8 @@ ControllerObject::ControllerObject(
: _api(mtproto, weak.runner()) : _api(mtproto, weak.runner())
, _state(PasswordCheckState{}) { , _state(PasswordCheckState{}) {
_api.errors( _api.errors(
) | rpl::start_with_next([=](RPCError &&error) { ) | rpl::start_with_next([=](const RPCError &error) {
setState(ApiErrorState{ std::move(error) }); setState(ApiErrorState{ error });
}, _lifetime); }, _lifetime);
_api.ioErrors( _api.ioErrors(

View file

@ -420,18 +420,14 @@ void Account::startMtp(std::unique_ptr<MTP::Config> config) {
_mtpFields.mainDcId = _mtp->mainDcId(); _mtpFields.mainDcId = _mtp->mainDcId();
_mtp->setUpdatesHandler(::rpcDone([=]( _mtp->setUpdatesHandler([=](const MTP::Response &message) {
const mtpPrime *from, checkForUpdates(message) || checkForNewSession(message);
const mtpPrime *end) { });
return checkForUpdates(from, end) _mtp->setGlobalFailHandler([=](const RPCError &, const MTP::Response &) {
|| checkForNewSession(from, end);
}));
_mtp->setGlobalFailHandler(::rpcFail([=](const RPCError &error) {
if (const auto session = maybeSession()) { if (const auto session = maybeSession()) {
crl::on_main(session, [=] { logOut(); }); crl::on_main(session, [=] { logOut(); });
} }
return true; });
}));
_mtp->setStateChangedHandler([=](MTP::ShiftedDcId dc, int32 state) { _mtp->setStateChangedHandler([=](MTP::ShiftedDcId dc, int32 state) {
if (dc == _mtp->mainDcId()) { if (dc == _mtp->mainDcId()) {
Global::RefConnectionTypeChanged().notify(); Global::RefConnectionTypeChanged().notify();
@ -468,18 +464,20 @@ void Account::startMtp(std::unique_ptr<MTP::Config> config) {
_mtpValue = _mtp.get(); _mtpValue = _mtp.get();
} }
bool Account::checkForUpdates(const mtpPrime *from, const mtpPrime *end) { bool Account::checkForUpdates(const MTP::Response &message) {
auto updates = MTPUpdates(); auto updates = MTPUpdates();
if (!updates.read(from, end)) { auto from = message.reply.constData();
if (!updates.read(from, from + message.reply.size())) {
return false; return false;
} }
_mtpUpdates.fire(std::move(updates)); _mtpUpdates.fire(std::move(updates));
return true; return true;
} }
bool Account::checkForNewSession(const mtpPrime *from, const mtpPrime *end) { bool Account::checkForNewSession(const MTP::Response &message) {
auto newSession = MTPNewSession(); auto newSession = MTPNewSession();
if (!newSession.read(from, end)) { auto from = message.reply.constData();
if (!newSession.read(from, from + message.reply.size())) {
return false; return false;
} }
_mtpNewSessionCreated.fire({}); _mtpNewSessionCreated.fire({});

View file

@ -122,8 +122,8 @@ private:
std::unique_ptr<SessionSettings> settings); std::unique_ptr<SessionSettings> settings);
void watchProxyChanges(); void watchProxyChanges();
void watchSessionChanges(); void watchSessionChanges();
bool checkForUpdates(const mtpPrime *from, const mtpPrime *end); bool checkForUpdates(const MTP::Response &message);
bool checkForNewSession(const mtpPrime *from, const mtpPrime *end); bool checkForNewSession(const MTP::Response &message);
void destroyMtpKeys(MTP::AuthKeysList &&keys); void destroyMtpKeys(MTP::AuthKeysList &&keys);
void resetAuthorizationKeys(); void resetAuthorizationKeys();

View file

@ -27,8 +27,8 @@ constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to w
ConfigLoader::ConfigLoader( ConfigLoader::ConfigLoader(
not_null<Instance*> instance, not_null<Instance*> instance,
const QString &phone, const QString &phone,
RPCDoneHandlerPtr onDone, Fn<void(const MTPConfig &result)> onDone,
RPCFailHandlerPtr onFail) FailHandler onFail)
: _instance(instance) : _instance(instance)
, _phone(phone) , _phone(phone)
, _doneHandler(onDone) , _doneHandler(onDone)
@ -50,9 +50,18 @@ void ConfigLoader::load() {
} }
mtpRequestId ConfigLoader::sendRequest(ShiftedDcId shiftedDcId) { mtpRequestId ConfigLoader::sendRequest(ShiftedDcId shiftedDcId) {
auto done = [done = _doneHandler](const Response &response) {
auto from = response.reply.constData();
auto result = MTPConfig();
if (!result.read(from, from + response.reply.size())) {
return false;
}
done(result);
return true;
};
return _instance->send( return _instance->send(
MTPhelp_GetConfig(), MTPhelp_GetConfig(),
base::duplicate(_doneHandler), std::move(done),
base::duplicate(_failHandler), base::duplicate(_failHandler),
shiftedDcId); shiftedDcId);
} }
@ -191,11 +200,17 @@ void ConfigLoader::sendSpecialRequest() {
endpoint->secret); endpoint->secret);
_specialEnumRequest = _instance->send( _specialEnumRequest = _instance->send(
MTPhelp_GetConfig(), MTPhelp_GetConfig(),
rpcDone([weak](const MTPConfig &result) { [weak](const Response &response) {
auto result = MTPConfig();
auto from = response.reply.constData();
if (!result.read(from, from + response.reply.size())) {
return false;
}
if (const auto strong = weak.get()) { if (const auto strong = weak.get()) {
strong->specialConfigLoaded(result); strong->specialConfigLoaded(result);
} }
}), return true;
},
base::duplicate(_failHandler), base::duplicate(_failHandler),
_specialEnumCurrent); _specialEnumCurrent);
_triedSpecialEndpoints.push_back(*endpoint); _triedSpecialEndpoints.push_back(*endpoint);

View file

@ -25,8 +25,8 @@ public:
ConfigLoader( ConfigLoader(
not_null<Instance*> instance, not_null<Instance*> instance,
const QString &phone, const QString &phone,
RPCDoneHandlerPtr onDone, Fn<void(const MTPConfig &result)> onDone,
RPCFailHandlerPtr onFail); FailHandler onFail);
~ConfigLoader(); ~ConfigLoader();
void load(); void load();
@ -68,8 +68,8 @@ private:
mtpRequestId _specialEnumRequest = 0; mtpRequestId _specialEnumRequest = 0;
QString _phone; QString _phone;
RPCDoneHandlerPtr _doneHandler; Fn<void(const MTPConfig &result)> _doneHandler;
RPCFailHandlerPtr _failHandler; FailHandler _failHandler;
}; };

View file

@ -19,10 +19,10 @@ class WeakInstance : private QObject, private base::Subscriber {
public: public:
explicit WeakInstance(base::weak_ptr<Main::Session> session); explicit WeakInstance(base::weak_ptr<Main::Session> session);
template <typename T> template <typename Request>
void send( void send(
const T &request, const Request &request,
Fn<void(const typename T::ResponseType &result)> done, Fn<void(const typename Request::ResponseType &result)> done,
Fn<void(const RPCError &error)> fail, Fn<void(const RPCError &error)> fail,
ShiftedDcId dcId = 0); ShiftedDcId dcId = 0);
@ -162,39 +162,44 @@ void StartDedicatedLoader(
const QString &folder, const QString &folder,
Fn<void(std::unique_ptr<DedicatedLoader>)> ready); Fn<void(std::unique_ptr<DedicatedLoader>)> ready);
template <typename T> template <typename Request>
void WeakInstance::send( void WeakInstance::send(
const T &request, const Request &request,
Fn<void(const typename T::ResponseType &result)> done, Fn<void(const typename Request::ResponseType &result)> done,
Fn<void(const RPCError &error)> fail, Fn<void(const RPCError &error)> fail,
MTP::ShiftedDcId dcId) { MTP::ShiftedDcId dcId) {
using Response = typename T::ResponseType; using Result = typename Request::ResponseType;
if (!valid()) { if (!valid()) {
reportUnavailable(fail); reportUnavailable(fail);
return; return;
} }
const auto onDone = crl::guard((QObject*)this, [=]( const auto onDone = crl::guard((QObject*)this, [=](
const Response &result, const Response &response) {
mtpRequestId requestId) { auto result = Result();
if (removeRequest(requestId)) { auto from = response.reply.constData();
if (!result.read(from, from + response.reply.size())) {
return false;
}
if (removeRequest(response.requestId)) {
done(result); done(result);
} }
return true;
}); });
const auto onFail = crl::guard((QObject*)this, [=]( const auto onFail = crl::guard((QObject*)this, [=](
const RPCError &error, const RPCError &error,
mtpRequestId requestId) { const Response &response) {
if (MTP::isDefaultHandledError(error)) { if (MTP::isDefaultHandledError(error)) {
return false; return false;
} }
if (removeRequest(requestId)) { if (removeRequest(response.requestId)) {
fail(error); fail(error);
} }
return true; return true;
}); });
const auto requestId = _instance->send( const auto requestId = _instance->send(
request, request,
rpcDone(onDone), std::move(onDone),
rpcFail(onFail), std::move(onFail),
dcId); dcId);
_requests.emplace(requestId, fail); _requests.emplace(requestId, fail);
} }

View file

@ -119,7 +119,7 @@ public:
void sendRequest( void sendRequest(
mtpRequestId requestId, mtpRequestId requestId,
SerializedRequest &&request, SerializedRequest &&request,
RPCResponseHandler &&callbacks, ResponseHandler &&callbacks,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
crl::time msCanWait, crl::time msCanWait,
bool needsLayer, bool needsLayer,
@ -129,23 +129,30 @@ public:
void storeRequest( void storeRequest(
mtpRequestId requestId, mtpRequestId requestId,
const SerializedRequest &request, const SerializedRequest &request,
RPCResponseHandler &&callbacks); ResponseHandler &&callbacks);
SerializedRequest getRequest(mtpRequestId requestId); SerializedRequest getRequest(mtpRequestId requestId);
void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end); [[nodiscard]] bool hasCallback(mtpRequestId requestId) const;
bool hasCallbacks(mtpRequestId requestId); void processCallback(const Response &response);
void globalCallback(const mtpPrime *from, const mtpPrime *end); void processUpdate(const Response &message);
void onStateChange(ShiftedDcId shiftedDcId, int32 state); void onStateChange(ShiftedDcId shiftedDcId, int32 state);
void onSessionReset(ShiftedDcId shiftedDcId); void onSessionReset(ShiftedDcId shiftedDcId);
// return true if need to clean request data // return true if need to clean request data
bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err); bool rpcErrorOccured(
inline bool rpcErrorOccured(mtpRequestId requestId, const RPCResponseHandler &handler, const RPCError &err) { const Response &response,
return rpcErrorOccured(requestId, handler.onFail, err); const FailHandler &onFail,
const RPCError &error);
inline bool rpcErrorOccured(
const Response &response,
const ResponseHandler &handler,
const RPCError &error) {
return rpcErrorOccured(response, handler.fail, error);
} }
void setUpdatesHandler(RPCDoneHandlerPtr onDone); void setUpdatesHandler(Fn<void(const Response&)> handler);
void setGlobalFailHandler(RPCFailHandlerPtr onFail); void setGlobalFailHandler(
Fn<void(const RPCError&, const Response&)> handler);
void setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler); void setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler);
void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler); void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler);
void clearGlobalHandlers(); void clearGlobalHandlers();
@ -170,11 +177,15 @@ public:
[[nodiscard]] rpl::lifetime &lifetime(); [[nodiscard]] rpl::lifetime &lifetime();
private: private:
void importDone(const MTPauth_Authorization &result, mtpRequestId requestId); void importDone(
bool importFail(const RPCError &error, mtpRequestId requestId); const MTPauth_Authorization &result,
void exportDone(const MTPauth_ExportedAuthorization &result, mtpRequestId requestId); const Response &response);
bool exportFail(const RPCError &error, mtpRequestId requestId); bool importFail(const RPCError &error, const Response &response);
bool onErrorDefault(mtpRequestId requestId, const RPCError &error); void exportDone(
const MTPauth_ExportedAuthorization &result,
const Response &response);
bool exportFail(const RPCError &error, const Response &response);
bool onErrorDefault(const RPCError &error, const Response &response);
void unpaused(); void unpaused();
@ -245,8 +256,8 @@ private:
// holds target dcWithShift for auth export request // holds target dcWithShift for auth export request
std::map<mtpRequestId, ShiftedDcId> _authExportRequests; std::map<mtpRequestId, ShiftedDcId> _authExportRequests;
std::map<mtpRequestId, RPCResponseHandler> _parserMap; std::map<mtpRequestId, ResponseHandler> _parserMap;
QMutex _parserMapLock; mutable QMutex _parserMapLock;
std::map<mtpRequestId, SerializedRequest> _requestMap; std::map<mtpRequestId, SerializedRequest> _requestMap;
QReadWriteLock _requestMapLock; QReadWriteLock _requestMapLock;
@ -259,7 +270,8 @@ private:
std::map<DcId, std::vector<mtpRequestId>> _authWaiters; std::map<DcId, std::vector<mtpRequestId>> _authWaiters;
RPCResponseHandler _globalHandler; Fn<void(const Response&)> _updatesHandler;
Fn<void(const RPCError&, const Response&)> _globalFailHandler;
Fn<void(ShiftedDcId shiftedDcId, int32 state)> _stateChangedHandler; Fn<void(ShiftedDcId shiftedDcId, int32 state)> _stateChangedHandler;
Fn<void(ShiftedDcId shiftedDcId)> _sessionResetHandler; Fn<void(ShiftedDcId shiftedDcId)> _sessionResetHandler;
@ -450,8 +462,10 @@ void Instance::Private::requestConfig() {
_configLoader = std::make_unique<ConfigLoader>( _configLoader = std::make_unique<ConfigLoader>(
_instance, _instance,
_userPhone, _userPhone,
rpcDone([=](const MTPConfig &result) { configLoadDone(result); }), [=](const MTPConfig &result) { configLoadDone(result); },
rpcFail([=](const RPCError &error) { return configLoadFail(error); })); [=](const RPCError &error, const Response &) {
return configLoadFail(error);
});
_configLoader->load(); _configLoader->load();
} }
@ -647,12 +661,13 @@ void Instance::Private::reInitConnection(DcId dcId) {
} }
void Instance::Private::logout(Fn<void()> done) { void Instance::Private::logout(Fn<void()> done) {
_instance->send(MTPauth_LogOut(), rpcDone([=] { _instance->send(MTPauth_LogOut(), [=](Response) {
done();
}), rpcFail([=] {
done(); done();
return true; return true;
})); }, [=](const RPCError&, Response) {
done();
return true;
});
logoutGuestDcs(); logoutGuestDcs();
} }
@ -667,12 +682,14 @@ void Instance::Private::logoutGuestDcs() {
continue; continue;
} }
const auto shiftedDcId = MTP::logoutDcId(dcId); const auto shiftedDcId = MTP::logoutDcId(dcId);
const auto requestId = _instance->send(MTPauth_LogOut(), rpcDone([=]( const auto requestId = _instance->send(MTPauth_LogOut(), [=](
mtpRequestId requestId) { const Response &response) {
logoutGuestDone(requestId); logoutGuestDone(response.requestId);
}), rpcFail([=](mtpRequestId requestId) { return true;
return logoutGuestDone(requestId); }, [=](const RPCError &, const Response &response) {
}), shiftedDcId); logoutGuestDone(response.requestId);
return true;
}, shiftedDcId);
_logoutGuestRequestIds.emplace(shiftedDcId, requestId); _logoutGuestRequestIds.emplace(shiftedDcId, requestId);
} }
} }
@ -932,7 +949,7 @@ void Instance::Private::checkDelayedRequests() {
void Instance::Private::sendRequest( void Instance::Private::sendRequest(
mtpRequestId requestId, mtpRequestId requestId,
SerializedRequest &&request, SerializedRequest &&request,
RPCResponseHandler &&callbacks, ResponseHandler &&callbacks,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
crl::time msCanWait, crl::time msCanWait,
bool needsLayer, bool needsLayer,
@ -980,8 +997,8 @@ void Instance::Private::unregisterRequest(mtpRequestId requestId) {
void Instance::Private::storeRequest( void Instance::Private::storeRequest(
mtpRequestId requestId, mtpRequestId requestId,
const SerializedRequest &request, const SerializedRequest &request,
RPCResponseHandler &&callbacks) { ResponseHandler &&callbacks) {
if (callbacks.onDone || callbacks.onFail) { if (callbacks.done || callbacks.fail) {
QMutexLocker locker(&_parserMapLock); QMutexLocker locker(&_parserMapLock);
_parserMap.emplace(requestId, std::move(callbacks)); _parserMap.emplace(requestId, std::move(callbacks));
} }
@ -1003,53 +1020,58 @@ SerializedRequest Instance::Private::getRequest(mtpRequestId requestId) {
return result; return result;
} }
bool Instance::Private::hasCallback(mtpRequestId requestId) const {
QMutexLocker locker(&_parserMapLock);
auto it = _parserMap.find(requestId);
return (it != _parserMap.cend());
}
void Instance::Private::execCallback( void Instance::Private::processCallback(const Response &response) {
mtpRequestId requestId, const auto requestId = response.requestId;
const mtpPrime *from, ResponseHandler handler;
const mtpPrime *end) {
RPCResponseHandler h;
{ {
QMutexLocker locker(&_parserMapLock); QMutexLocker locker(&_parserMapLock);
auto it = _parserMap.find(requestId); auto it = _parserMap.find(requestId);
if (it != _parserMap.cend()) { if (it != _parserMap.cend()) {
h = it->second; handler = std::move(it->second);
_parserMap.erase(it); _parserMap.erase(it);
DEBUG_LOG(("RPC Info: found parser for request %1, trying to parse response...").arg(requestId)); DEBUG_LOG(("RPC Info: found parser for request %1, trying to parse response...").arg(requestId));
} }
} }
if (h.onDone || h.onFail) { if (handler.done || handler.fail) {
const auto handleError = [&](const RPCError &error) { const auto handleError = [&](const RPCError &error) {
DEBUG_LOG(("RPC Info: " DEBUG_LOG(("RPC Info: "
"error received, code %1, type %2, description: %3" "error received, code %1, type %2, description: %3"
).arg(error.code() ).arg(error.code()
).arg(error.type() ).arg(error.type()
).arg(error.description())); ).arg(error.description()));
if (rpcErrorOccured(requestId, h, error)) { if (rpcErrorOccured(response, handler, error)) {
unregisterRequest(requestId); unregisterRequest(requestId);
} else { } else {
QMutexLocker locker(&_parserMapLock); QMutexLocker locker(&_parserMapLock);
_parserMap.emplace(requestId, h); _parserMap.emplace(requestId, std::move(handler));
} }
}; };
if (from >= end) { auto from = response.reply.constData();
if (response.reply.isEmpty()) {
handleError(RPCError::Local( handleError(RPCError::Local(
"RESPONSE_PARSE_FAILED", "RESPONSE_PARSE_FAILED",
"Empty response.")); "Empty response."));
} else if (*from == mtpc_rpc_error) { } else if (*from == mtpc_rpc_error) {
auto error = MTPRpcError(); auto error = MTPRpcError();
handleError(error.read(from, end) ? error : RPCError::Local( handleError(
"RESPONSE_PARSE_FAILED", RPCError(error.read(from, from + response.reply.size())
"Error parse failed.")); ? error
} else { : RPCError::MTPLocal(
if (h.onDone) {
if (!(*h.onDone)(requestId, from, end)) {
handleError(RPCError::Local(
"RESPONSE_PARSE_FAILED", "RESPONSE_PARSE_FAILED",
"Response parse failed.")); "Error parse failed.")));
} } else {
if (handler.done && !handler.done(response)) {
handleError(RPCError::Local(
"RESPONSE_PARSE_FAILED",
"Response parse failed."));
} }
unregisterRequest(requestId); unregisterRequest(requestId);
} }
@ -1059,18 +1081,10 @@ void Instance::Private::execCallback(
} }
} }
bool Instance::Private::hasCallbacks(mtpRequestId requestId) { void Instance::Private::processUpdate(const Response &message) {
QMutexLocker locker(&_parserMapLock); if (_updatesHandler) {
auto it = _parserMap.find(requestId); _updatesHandler(message);
return (it != _parserMap.cend());
}
void Instance::Private::globalCallback(const mtpPrime *from, const mtpPrime *end) {
if (!_globalHandler.onDone) {
return;
} }
// Handle updates.
[[maybe_unused]] bool result = (*_globalHandler.onDone)(0, from, end);
} }
void Instance::Private::onStateChange(ShiftedDcId dcWithShift, int32 state) { void Instance::Private::onStateChange(ShiftedDcId dcWithShift, int32 state) {
@ -1085,25 +1099,40 @@ void Instance::Private::onSessionReset(ShiftedDcId dcWithShift) {
} }
} }
bool Instance::Private::rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err) { // return true if need to clean request data bool Instance::Private::rpcErrorOccured(
if (isDefaultHandledError(err)) { const Response &response,
if (onFail && (*onFail)(requestId, err)) { const FailHandler &onFail,
const RPCError &error) { // return true if need to clean request data
if (isDefaultHandledError(error)) {
if (onFail && onFail(error, response)) {
return true; return true;
} }
} }
if (onErrorDefault(requestId, err)) { if (onErrorDefault(error, response)) {
return false; return false;
} }
LOG(("RPC Error: request %1 got fail with code %2, error %3%4").arg(requestId).arg(err.code()).arg(err.type()).arg(err.description().isEmpty() ? QString() : QString(": %1").arg(err.description()))); LOG(("RPC Error: request %1 got fail with code %2, error %3%4"
onFail && (*onFail)(requestId, err); ).arg(response.requestId
).arg(error.code()
).arg(error.type()
).arg(error.description().isEmpty()
? QString()
: QString(": %1").arg(error.description())));
if (onFail) {
onFail(error, response);
}
return true; return true;
} }
void Instance::Private::importDone(const MTPauth_Authorization &result, mtpRequestId requestId) { void Instance::Private::importDone(
const auto shiftedDcId = queryRequestByDc(requestId); const MTPauth_Authorization &result,
const Response &response) {
const auto shiftedDcId = queryRequestByDc(response.requestId);
if (!shiftedDcId) { if (!shiftedDcId) {
LOG(("MTP Error: auth import request not found in requestsByDC, requestId: %1").arg(requestId)); LOG(("MTP Error: "
"auth import request not found in requestsByDC, requestId: %1"
).arg(response.requestId));
// //
// Don't log out on export/import problems, perhaps this is a server side error. // Don't log out on export/import problems, perhaps this is a server side error.
// //
@ -1111,8 +1140,8 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque
// "AUTH_IMPORT_FAIL", // "AUTH_IMPORT_FAIL",
// QString("did not find import request in requestsByDC, " // QString("did not find import request in requestsByDC, "
// "request %1").arg(requestId)); // "request %1").arg(requestId));
//if (_globalHandler.onFail && hasAuthorization()) { //if (_globalFailHandler && hasAuthorization()) {
// (*_globalHandler.onFail)(requestId, error); // auth failed in main dc // _globalFailHandler(error, response); // auth failed in main dc
//} //}
return; return;
} }
@ -1144,22 +1173,30 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque
} }
} }
bool Instance::Private::importFail(const RPCError &error, mtpRequestId requestId) { bool Instance::Private::importFail(
if (isDefaultHandledError(error)) return false; const RPCError &error,
const Response &response) {
if (isDefaultHandledError(error)) {
return false;
}
// //
// Don't log out on export/import problems, perhaps this is a server side error. // Don't log out on export/import problems, perhaps this is a server side error.
// //
//if (_globalHandler.onFail && hasAuthorization()) { //if (_globalFailHandler && hasAuthorization()) {
// (*_globalHandler.onFail)(requestId, error); // auth import failed // _globalFailHandler(error, response); // auth import failed
//} //}
return true; return true;
} }
void Instance::Private::exportDone(const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { void Instance::Private::exportDone(
auto it = _authExportRequests.find(requestId); const MTPauth_ExportedAuthorization &result,
const Response &response) {
auto it = _authExportRequests.find(response.requestId);
if (it == _authExportRequests.cend()) { if (it == _authExportRequests.cend()) {
LOG(("MTP Error: auth export request target dcWithShift not found, requestId: %1").arg(requestId)); LOG(("MTP Error: "
"auth export request target dcWithShift not found, requestId: %1"
).arg(response.requestId));
// //
// Don't log out on export/import problems, perhaps this is a server side error. // Don't log out on export/import problems, perhaps this is a server side error.
// //
@ -1167,46 +1204,62 @@ void Instance::Private::exportDone(const MTPauth_ExportedAuthorization &result,
// "AUTH_IMPORT_FAIL", // "AUTH_IMPORT_FAIL",
// QString("did not find target dcWithShift, request %1" // QString("did not find target dcWithShift, request %1"
// ).arg(requestId)); // ).arg(requestId));
//if (_globalHandler.onFail && hasAuthorization()) { //if (_globalFailHandler && hasAuthorization()) {
// (*_globalHandler.onFail)(requestId, error); // auth failed in main dc // _globalFailHandler(error, response); // auth failed in main dc
//} //}
return; return;
} }
auto &data = result.c_auth_exportedAuthorization(); auto &data = result.c_auth_exportedAuthorization();
_instance->send(MTPauth_ImportAuthorization(data.vid(), data.vbytes()), rpcDone([this](const MTPauth_Authorization &result, mtpRequestId requestId) { _instance->send(MTPauth_ImportAuthorization(
importDone(result, requestId); data.vid(),
}), rpcFail([this](const RPCError &error, mtpRequestId requestId) { data.vbytes()
return importFail(error, requestId); ), [this](const Response &response) {
}), it->second); auto result = MTPauth_Authorization();
_authExportRequests.erase(requestId); auto from = response.reply.constData();
if (!result.read(from, from + response.reply.size())) {
return false;
}
importDone(result, response);
return true;
}, [this](const RPCError &error, const Response &response) {
return importFail(error, response);
}, it->second);
_authExportRequests.erase(response.requestId);
} }
bool Instance::Private::exportFail(const RPCError &error, mtpRequestId requestId) { bool Instance::Private::exportFail(
if (isDefaultHandledError(error)) return false; const RPCError &error,
const Response &response) {
if (isDefaultHandledError(error)) {
return false;
}
auto it = _authExportRequests.find(requestId); auto it = _authExportRequests.find(response.requestId);
if (it != _authExportRequests.cend()) { if (it != _authExportRequests.cend()) {
_authWaiters[BareDcId(it->second)].clear(); _authWaiters[BareDcId(it->second)].clear();
} }
// //
// Don't log out on export/import problems, perhaps this is a server side error. // Don't log out on export/import problems, perhaps this is a server side error.
// //
//if (_globalHandler.onFail && hasAuthorization()) { //if (_globalFailHandler && hasAuthorization()) {
// (*_globalHandler.onFail)(requestId, error); // auth failed in main dc // _globalFailHandler(error, response); // auth failed in main dc
//} //}
return true; return true;
} }
bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &error) { bool Instance::Private::onErrorDefault(
auto &err(error.type()); const RPCError &error,
auto code = error.code(); const Response &response) {
if (!isFloodError(error) && err != qstr("AUTH_KEY_UNREGISTERED")) { const auto requestId = response.requestId;
const auto &type = error.type();
const auto code = error.code();
if (!isFloodError(error) && type != qstr("AUTH_KEY_UNREGISTERED")) {
int breakpoint = 0; int breakpoint = 0;
} }
auto badGuestDc = (code == 400) && (err == qsl("FILE_ID_INVALID")); auto badGuestDc = (code == 400) && (type == qsl("FILE_ID_INVALID"));
QRegularExpressionMatch m; QRegularExpressionMatch m;
if ((m = QRegularExpression("^(FILE|PHONE|NETWORK|USER)_MIGRATE_(\\d+)$").match(err)).hasMatch()) { if ((m = QRegularExpression("^(FILE|PHONE|NETWORK|USER)_MIGRATE_(\\d+)$").match(type)).hasMatch()) {
if (!requestId) return false; if (!requestId) return false;
auto dcWithShift = ShiftedDcId(0); auto dcWithShift = ShiftedDcId(0);
@ -1228,11 +1281,19 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e
//DEBUG_LOG(("MTP Info: importing auth to dc %1").arg(newdcWithShift)); //DEBUG_LOG(("MTP Info: importing auth to dc %1").arg(newdcWithShift));
//auto &waiters(_authWaiters[newdcWithShift]); //auto &waiters(_authWaiters[newdcWithShift]);
//if (waiters.empty()) { //if (waiters.empty()) {
// auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(MTP_int(newdcWithShift)), rpcDone([this](const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { // auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(
// exportDone(result, requestId); // MTP_int(newdcWithShift)
// }), rpcFail([this](const RPCError &error, mtpRequestId requestId) { // ), [this](const Response &response) {
// return exportFail(error, requestId); // auto result = MTPauth_ExportedAuthorization();
// })); // auto from = response.reply.constData();
// if (!result.read(from, from + response.reply.size())) {
// return false;
// }
// exportDone(result, response);
// return true;
// }, [this](const RPCError &error, const Response &response) {
// return exportFail(error, response);
// });
// _authExportRequests.emplace(exportRequestId, newdcWithShift); // _authExportRequests.emplace(exportRequestId, newdcWithShift);
//} //}
//waiters.push_back(requestId); //waiters.push_back(requestId);
@ -1260,7 +1321,7 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e
(dcWithShift < 0) ? -newdcWithShift : newdcWithShift); (dcWithShift < 0) ? -newdcWithShift : newdcWithShift);
session->sendPrepared(request); session->sendPrepared(request);
return true; return true;
} else if (code < 0 || code >= 500 || (m = QRegularExpression("^FLOOD_WAIT_(\\d+)$").match(err)).hasMatch()) { } else if (code < 0 || code >= 500 || (m = QRegularExpression("^FLOOD_WAIT_(\\d+)$").match(type)).hasMatch()) {
if (!requestId) return false; if (!requestId) return false;
int32 secs = 1; int32 secs = 1;
@ -1286,7 +1347,7 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e
checkDelayedRequests(); checkDelayedRequests();
return true; return true;
} else if ((code == 401 && err != "AUTH_KEY_PERM_EMPTY") } else if ((code == 401 && type != qstr("AUTH_KEY_PERM_EMPTY"))
|| (badGuestDc && _badGuestDcRequests.find(requestId) == _badGuestDcRequests.cend())) { || (badGuestDc && _badGuestDcRequests.find(requestId) == _badGuestDcRequests.cend())) {
auto dcWithShift = ShiftedDcId(0); auto dcWithShift = ShiftedDcId(0);
if (const auto shiftedDcId = queryRequestByDc(requestId)) { if (const auto shiftedDcId = queryRequestByDc(requestId)) {
@ -1296,26 +1357,36 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e
} }
auto newdc = BareDcId(qAbs(dcWithShift)); auto newdc = BareDcId(qAbs(dcWithShift));
if (!newdc || newdc == mainDcId()) { if (!newdc || newdc == mainDcId()) {
if (!badGuestDc && _globalHandler.onFail) { if (!badGuestDc && _globalFailHandler) {
(*_globalHandler.onFail)(requestId, error); // auth failed in main dc _globalFailHandler(error, response); // auth failed in main dc
} }
return false; return false;
} }
DEBUG_LOG(("MTP Info: importing auth to dcWithShift %1").arg(dcWithShift)); DEBUG_LOG(("MTP Info: importing auth to dcWithShift %1"
).arg(dcWithShift));
auto &waiters(_authWaiters[newdc]); auto &waiters(_authWaiters[newdc]);
if (!waiters.size()) { if (!waiters.size()) {
auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(MTP_int(newdc)), rpcDone([this](const MTPauth_ExportedAuthorization &result, mtpRequestId requestId) { auto exportRequestId = _instance->send(MTPauth_ExportAuthorization(
exportDone(result, requestId); MTP_int(newdc)
}), rpcFail([this](const RPCError &error, mtpRequestId requestId) { ), [this](const Response &response) {
return exportFail(error, requestId); auto result = MTPauth_ExportedAuthorization();
})); auto from = response.reply.constData();
if (!result.read(from, from + response.reply.size())) {
return false;
}
exportDone(result, response);
return true;
}, [this](const RPCError &error, const Response &response) {
return exportFail(error, response);
});
_authExportRequests.emplace(exportRequestId, abs(dcWithShift)); _authExportRequests.emplace(exportRequestId, abs(dcWithShift));
} }
waiters.push_back(requestId); waiters.push_back(requestId);
if (badGuestDc) _badGuestDcRequests.insert(requestId); if (badGuestDc) _badGuestDcRequests.insert(requestId);
return true; return true;
} else if (err == qstr("CONNECTION_NOT_INITED") || err == qstr("CONNECTION_LAYER_INVALID")) { } else if (type == qstr("CONNECTION_NOT_INITED")
|| type == qstr("CONNECTION_LAYER_INVALID")) {
SerializedRequest request; SerializedRequest request;
{ {
QReadLocker locker(&_requestMapLock); QReadLocker locker(&_requestMapLock);
@ -1338,9 +1409,9 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e
request->needsLayer = true; request->needsLayer = true;
session->sendPrepared(request); session->sendPrepared(request);
return true; return true;
} else if (err == qstr("CONNECTION_LANG_CODE_INVALID")) { } else if (type == qstr("CONNECTION_LANG_CODE_INVALID")) {
Lang::CurrentCloudManager().resetToDefault(); Lang::CurrentCloudManager().resetToDefault();
} else if (err == qstr("MSG_WAIT_FAILED")) { } else if (type == qstr("MSG_WAIT_FAILED")) {
SerializedRequest request; SerializedRequest request;
{ {
QReadLocker locker(&_requestMapLock); QReadLocker locker(&_requestMapLock);
@ -1514,15 +1585,16 @@ void Instance::Private::scheduleKeyDestroy(ShiftedDcId shiftedDcId) {
if (dcOptions().dcType(shiftedDcId) == DcType::Cdn) { if (dcOptions().dcType(shiftedDcId) == DcType::Cdn) {
performKeyDestroy(shiftedDcId); performKeyDestroy(shiftedDcId);
} else { } else {
_instance->send(MTPauth_LogOut(), rpcDone([=](const MTPBool &) { _instance->send(MTPauth_LogOut(), [=](const Response &) {
performKeyDestroy(shiftedDcId); performKeyDestroy(shiftedDcId);
}), rpcFail([=](const RPCError &error) { return true;
}, [=](const RPCError &error, const Response &) {
if (isDefaultHandledError(error)) { if (isDefaultHandledError(error)) {
return false; return false;
} }
performKeyDestroy(shiftedDcId); performKeyDestroy(shiftedDcId);
return true; return true;
}), shiftedDcId); }, shiftedDcId);
} }
} }
@ -1539,21 +1611,29 @@ void Instance::Private::keyWasPossiblyDestroyed(ShiftedDcId shiftedDcId) {
void Instance::Private::performKeyDestroy(ShiftedDcId shiftedDcId) { void Instance::Private::performKeyDestroy(ShiftedDcId shiftedDcId) {
Expects(isKeysDestroyer()); Expects(isKeysDestroyer());
_instance->send(MTPDestroy_auth_key(), rpcDone([=](const MTPDestroyAuthKeyRes &result) { _instance->send(MTPDestroy_auth_key(), [=](const Response &response) {
switch (result.type()) { auto result = MTPDestroyAuthKeyRes();
case mtpc_destroy_auth_key_ok: LOG(("MTP Info: key %1 destroyed.").arg(shiftedDcId)); break; auto from = response.reply.constData();
case mtpc_destroy_auth_key_fail: { if (!result.read(from, from + response.reply.size())) {
LOG(("MTP Error: key %1 destruction fail, leave it for now.").arg(shiftedDcId)); return false;
killSession(shiftedDcId);
} break;
case mtpc_destroy_auth_key_none: LOG(("MTP Info: key %1 already destroyed.").arg(shiftedDcId)); break;
} }
_instance->keyWasPossiblyDestroyed(shiftedDcId); result.match([&](const MTPDdestroy_auth_key_ok &) {
}), rpcFail([=](const RPCError &error) { LOG(("MTP Info: key %1 destroyed.").arg(shiftedDcId));
LOG(("MTP Error: key %1 destruction resulted in error: %2").arg(shiftedDcId).arg(error.type())); }, [&](const MTPDdestroy_auth_key_fail &) {
LOG(("MTP Error: key %1 destruction fail, leave it for now."
).arg(shiftedDcId));
killSession(shiftedDcId);
}, [&](const MTPDdestroy_auth_key_none &) {
LOG(("MTP Info: key %1 already destroyed.").arg(shiftedDcId));
});
_instance->keyWasPossiblyDestroyed(shiftedDcId); _instance->keyWasPossiblyDestroyed(shiftedDcId);
return true; return true;
}), shiftedDcId); }, [=](const RPCError &error, const Response &response) {
LOG(("MTP Error: key %1 destruction resulted in error: %2"
).arg(shiftedDcId).arg(error.type()));
_instance->keyWasPossiblyDestroyed(shiftedDcId);
return true;
}, shiftedDcId);
} }
void Instance::Private::completedKeyDestroy(ShiftedDcId shiftedDcId) { void Instance::Private::completedKeyDestroy(ShiftedDcId shiftedDcId) {
@ -1582,27 +1662,31 @@ void Instance::Private::keyDestroyedOnServer(
restart(shiftedDcId); restart(shiftedDcId);
} }
void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) { void Instance::Private::setUpdatesHandler(
_globalHandler.onDone = onDone; Fn<void(const Response&)> handler) {
_updatesHandler = std::move(handler);
} }
void Instance::Private::setGlobalFailHandler(RPCFailHandlerPtr onFail) { void Instance::Private::setGlobalFailHandler(
_globalHandler.onFail = onFail; Fn<void(const RPCError&, const Response&)> handler) {
_globalFailHandler = std::move(handler);
} }
void Instance::Private::setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler) { void Instance::Private::setStateChangedHandler(
Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler) {
_stateChangedHandler = std::move(handler); _stateChangedHandler = std::move(handler);
} }
void Instance::Private::setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler) { void Instance::Private::setSessionResetHandler(
Fn<void(ShiftedDcId shiftedDcId)> handler) {
_sessionResetHandler = std::move(handler); _sessionResetHandler = std::move(handler);
} }
void Instance::Private::clearGlobalHandlers() { void Instance::Private::clearGlobalHandlers() {
setUpdatesHandler(RPCDoneHandlerPtr()); setUpdatesHandler(nullptr);
setGlobalFailHandler(RPCFailHandlerPtr()); setGlobalFailHandler(nullptr);
setStateChangedHandler(Fn<void(ShiftedDcId,int32)>()); setStateChangedHandler(nullptr);
setSessionResetHandler(Fn<void(ShiftedDcId)>()); setSessionResetHandler(nullptr);
} }
void Instance::Private::prepareToDestroy() { void Instance::Private::prepareToDestroy() {
@ -1806,19 +1890,22 @@ QString Instance::systemVersion() const {
return _private->systemVersion(); return _private->systemVersion();
} }
void Instance::setUpdatesHandler(RPCDoneHandlerPtr onDone) { void Instance::setUpdatesHandler(Fn<void(const Response&)> handler) {
_private->setUpdatesHandler(onDone); _private->setUpdatesHandler(std::move(handler));
} }
void Instance::setGlobalFailHandler(RPCFailHandlerPtr onFail) { void Instance::setGlobalFailHandler(
_private->setGlobalFailHandler(onFail); Fn<void(const RPCError&, const Response&)> handler) {
_private->setGlobalFailHandler(std::move(handler));
} }
void Instance::setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler) { void Instance::setStateChangedHandler(
Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler) {
_private->setStateChangedHandler(std::move(handler)); _private->setStateChangedHandler(std::move(handler));
} }
void Instance::setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler) { void Instance::setSessionResetHandler(
Fn<void(ShiftedDcId shiftedDcId)> handler) {
_private->setSessionResetHandler(std::move(handler)); _private->setSessionResetHandler(std::move(handler));
} }
@ -1834,20 +1921,23 @@ void Instance::onSessionReset(ShiftedDcId shiftedDcId) {
_private->onSessionReset(shiftedDcId); _private->onSessionReset(shiftedDcId);
} }
void Instance::execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) { bool Instance::hasCallback(mtpRequestId requestId) const {
_private->execCallback(requestId, from, end); return _private->hasCallback(requestId);
} }
bool Instance::hasCallbacks(mtpRequestId requestId) { void Instance::processCallback(const Response &response) {
return _private->hasCallbacks(requestId); _private->processCallback(response);
} }
void Instance::globalCallback(const mtpPrime *from, const mtpPrime *end) { void Instance::processUpdate(const Response &message) {
_private->globalCallback(from, end); _private->processUpdate(message);
} }
bool Instance::rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err) { bool Instance::rpcErrorOccured(
return _private->rpcErrorOccured(requestId, onFail, err); const Response &response,
const FailHandler &onFail,
const RPCError &error) {
return _private->rpcErrorOccured(response, onFail, error);
} }
bool Instance::isKeysDestroyer() const { bool Instance::isKeysDestroyer() const {
@ -1865,7 +1955,7 @@ void Instance::keyDestroyedOnServer(ShiftedDcId shiftedDcId, uint64 keyId) {
void Instance::sendRequest( void Instance::sendRequest(
mtpRequestId requestId, mtpRequestId requestId,
SerializedRequest &&request, SerializedRequest &&request,
RPCResponseHandler &&callbacks, ResponseHandler &&callbacks,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
crl::time msCanWait, crl::time msCanWait,
bool needsLayer, bool needsLayer,

View file

@ -102,21 +102,26 @@ public:
void reInitConnection(DcId dcId); void reInitConnection(DcId dcId);
void logout(Fn<void()> done); void logout(Fn<void()> done);
void setUpdatesHandler(RPCDoneHandlerPtr onDone); void setUpdatesHandler(Fn<void(const Response&)> handler);
void setGlobalFailHandler(RPCFailHandlerPtr onFail); void setGlobalFailHandler(
void setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler); Fn<void(const RPCError&, const Response&)> handler);
void setStateChangedHandler(
Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler);
void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler); void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler);
void clearGlobalHandlers(); void clearGlobalHandlers();
void onStateChange(ShiftedDcId shiftedDcId, int32 state); void onStateChange(ShiftedDcId shiftedDcId, int32 state);
void onSessionReset(ShiftedDcId shiftedDcId); void onSessionReset(ShiftedDcId shiftedDcId);
void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end); [[nodiscard]] bool hasCallback(mtpRequestId requestId) const;
bool hasCallbacks(mtpRequestId requestId); void processCallback(const Response &response);
void globalCallback(const mtpPrime *from, const mtpPrime *end); void processUpdate(const Response &message);
// return true if need to clean request data // return true if need to clean request data
bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err); bool rpcErrorOccured(
const Response &response,
const FailHandler &onFail,
const RPCError &err);
// Thread-safe. // Thread-safe.
bool isKeysDestroyer() const; bool isKeysDestroyer() const;
@ -141,7 +146,7 @@ public:
template <typename Request> template <typename Request>
mtpRequestId send( mtpRequestId send(
const Request &request, const Request &request,
RPCResponseHandler &&callbacks = {}, ResponseHandler &&callbacks = {},
ShiftedDcId shiftedDcId = 0, ShiftedDcId shiftedDcId = 0,
crl::time msCanWait = 0, crl::time msCanWait = 0,
mtpRequestId afterRequestId = 0) { mtpRequestId afterRequestId = 0) {
@ -159,14 +164,14 @@ public:
template <typename Request> template <typename Request>
mtpRequestId send( mtpRequestId send(
const Request &request, const Request &request,
RPCDoneHandlerPtr &&onDone, DoneHandler &&onDone,
RPCFailHandlerPtr &&onFail = nullptr, FailHandler &&onFail = nullptr,
ShiftedDcId shiftedDcId = 0, ShiftedDcId shiftedDcId = 0,
crl::time msCanWait = 0, crl::time msCanWait = 0,
mtpRequestId afterRequestId = 0) { mtpRequestId afterRequestId = 0) {
return send( return send(
request, request,
RPCResponseHandler(std::move(onDone), std::move(onFail)), ResponseHandler{ std::move(onDone), std::move(onFail) },
shiftedDcId, shiftedDcId,
msCanWait, msCanWait,
afterRequestId); afterRequestId);
@ -191,7 +196,7 @@ public:
void sendSerialized( void sendSerialized(
mtpRequestId requestId, mtpRequestId requestId,
details::SerializedRequest &&request, details::SerializedRequest &&request,
RPCResponseHandler &&callbacks, ResponseHandler &&callbacks,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
crl::time msCanWait, crl::time msCanWait,
mtpRequestId afterRequestId) { mtpRequestId afterRequestId) {
@ -218,7 +223,7 @@ private:
void sendRequest( void sendRequest(
mtpRequestId requestId, mtpRequestId requestId,
details::SerializedRequest &&request, details::SerializedRequest &&request,
RPCResponseHandler &&callbacks, ResponseHandler &&callbacks,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
crl::time msCanWait, crl::time msCanWait,
bool needsLayer, bool needsLayer,

View file

@ -13,90 +13,61 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP { namespace MTP {
class ConcurrentSender::RPCDoneHandler : public RPCAbstractDoneHandler { class ConcurrentSender::HandlerMaker final {
public: public:
RPCDoneHandler( static ::MTP::DoneHandler MakeDone(
not_null<ConcurrentSender*> sender, not_null<ConcurrentSender*> sender,
Fn<void(FnMut<void()>)> runner); Fn<void(FnMut<void()>)> runner);
static ::MTP::FailHandler MakeFail(
bool operator()(
mtpRequestId requestId,
const mtpPrime *from,
const mtpPrime *end) override;
private:
base::weak_ptr<ConcurrentSender> _weak;
Fn<void(FnMut<void()>)> _runner;
};
class ConcurrentSender::RPCFailHandler : public RPCAbstractFailHandler {
public:
RPCFailHandler(
not_null<ConcurrentSender*> sender, not_null<ConcurrentSender*> sender,
Fn<void(FnMut<void()>)> runner, Fn<void(FnMut<void()>)> runner,
FailSkipPolicy skipPolicy); FailSkipPolicy skipPolicy);
bool operator()(
mtpRequestId requestId,
const RPCError &error) override;
private:
base::weak_ptr<ConcurrentSender> _weak;
Fn<void(FnMut<void()>)> _runner;
FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;
}; };
ConcurrentSender::RPCDoneHandler::RPCDoneHandler( ::MTP::DoneHandler ConcurrentSender::HandlerMaker::MakeDone(
not_null<ConcurrentSender*> sender, not_null<ConcurrentSender*> sender,
Fn<void(FnMut<void()>)> runner) Fn<void(FnMut<void()>)> runner) {
: _weak(sender) return [
, _runner(std::move(runner)) { weak = base::make_weak(sender.get()),
runner = std::move(runner)
](const Response &response) mutable {
runner([=]() mutable {
if (const auto strong = weak.get()) {
strong->senderRequestDone(
response.requestId,
bytes::make_span(response.reply));
}
});
return true;
};
} }
bool ConcurrentSender::RPCDoneHandler::operator()( ::MTP::FailHandler ConcurrentSender::HandlerMaker::MakeFail(
mtpRequestId requestId, not_null<ConcurrentSender*> sender,
const mtpPrime *from, Fn<void(FnMut<void()>)> runner,
const mtpPrime *end) { FailSkipPolicy skipPolicy) {
auto response = gsl::make_span( return [
from, weak = base::make_weak(sender.get()),
end - from); runner = std::move(runner),
_runner([=, weak = _weak, moved = bytes::make_vector(response)]() mutable { skipPolicy
if (const auto strong = weak.get()) { ](const RPCError &error, const Response &response) mutable {
strong->senderRequestDone(requestId, std::move(moved)); if (skipPolicy == FailSkipPolicy::Simple) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
} else if (skipPolicy == FailSkipPolicy::HandleFlood) {
if (MTP::isDefaultHandledError(error)
&& !MTP::isFloodError(error)) {
return false;
}
} }
}); runner([=, requestId = response.requestId]() mutable {
return true; if (const auto strong = weak.get()) {
} strong->senderRequestFail(requestId, error);
}
ConcurrentSender::RPCFailHandler::RPCFailHandler( });
not_null<ConcurrentSender*> sender, return true;
Fn<void(FnMut<void()>)> runner, };
FailSkipPolicy skipPolicy)
: _weak(sender)
, _runner(std::move(runner))
, _skipPolicy(skipPolicy) {
}
bool ConcurrentSender::RPCFailHandler::operator()(
mtpRequestId requestId,
const RPCError &error) {
if (_skipPolicy == FailSkipPolicy::Simple) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
} else if (_skipPolicy == FailSkipPolicy::HandleFlood) {
if (MTP::isDefaultHandledError(error) && !MTP::isFloodError(error)) {
return false;
}
}
_runner([=, weak = _weak, error = error]() mutable {
if (const auto strong = weak.get()) {
strong->senderRequestFail(requestId, std::move(error));
}
});
return true;
} }
template <typename Method> template <typename Method>
@ -147,8 +118,8 @@ mtpRequestId ConcurrentSender::RequestBuilder::send() {
_sender->with_instance([ _sender->with_instance([
=, =,
request = std::move(_serialized), request = std::move(_serialized),
done = std::make_shared<RPCDoneHandler>(_sender, _sender->_runner), done = HandlerMaker::MakeDone(_sender, _sender->_runner),
fail = std::make_shared<RPCFailHandler>( fail = HandlerMaker::MakeFail(
_sender, _sender,
_sender->_runner, _sender->_runner,
_failSkipPolicy) _failSkipPolicy)
@ -156,7 +127,7 @@ mtpRequestId ConcurrentSender::RequestBuilder::send() {
instance->sendSerialized( instance->sendSerialized(
requestId, requestId,
std::move(request), std::move(request),
RPCResponseHandler(std::move(done), std::move(fail)), ResponseHandler{ std::move(done), std::move(fail) },
dcId, dcId,
msCanWait, msCanWait,
afterRequestId); afterRequestId);
@ -198,9 +169,9 @@ void ConcurrentSender::senderRequestDone(
void ConcurrentSender::senderRequestFail( void ConcurrentSender::senderRequestFail(
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error) { const RPCError &error) {
if (auto handlers = _requests.take(requestId)) { if (auto handlers = _requests.take(requestId)) {
handlers->fail(requestId, std::move(error)); handlers->fail(requestId, error);
} }
} }

View file

@ -40,7 +40,7 @@ class ConcurrentSender : public base::has_weak_ptr {
bytes::const_span result)>; bytes::const_span result)>;
using FailHandler = FnMut<void( using FailHandler = FnMut<void(
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error)>; const RPCError &error)>;
struct Handlers { struct Handlers {
DoneHandler done; DoneHandler done;
FailHandler fail; FailHandler fail;
@ -95,7 +95,7 @@ public:
template <typename Request> template <typename Request>
class SpecificRequestBuilder : public RequestBuilder { class SpecificRequestBuilder : public RequestBuilder {
public: public:
using Response = typename Request::ResponseType; using Result = typename Request::ResponseType;
SpecificRequestBuilder( SpecificRequestBuilder(
const SpecificRequestBuilder &other) = delete; const SpecificRequestBuilder &other) = delete;
@ -115,18 +115,14 @@ public:
// Allow code completion to show response type. // Allow code completion to show response type.
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void()> &&handler); [[nodiscard]] SpecificRequestBuilder &done(FnMut<void()> &&handler);
[[nodiscard]] SpecificRequestBuilder &done( [[nodiscard]] SpecificRequestBuilder &done(
FnMut<void(mtpRequestId)> &&handler); FnMut<void(mtpRequestId, Result &&)> &&handler);
[[nodiscard]] SpecificRequestBuilder &done( [[nodiscard]] SpecificRequestBuilder &done(
FnMut<void(mtpRequestId, Response &&)> &&handler); FnMut<void(Result &&)> &&handler);
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void( [[nodiscard]] SpecificRequestBuilder &fail(Fn<void()> &&handler);
Response &&)> &&handler);
[[nodiscard]] SpecificRequestBuilder &fail(FnMut<void()> &&handler);
[[nodiscard]] SpecificRequestBuilder &fail( [[nodiscard]] SpecificRequestBuilder &fail(
FnMut<void(mtpRequestId)> &&handler); Fn<void(mtpRequestId, const RPCError &)> &&handler);
[[nodiscard]] SpecificRequestBuilder &fail( [[nodiscard]] SpecificRequestBuilder &fail(
FnMut<void(mtpRequestId, RPCError &&)> &&handler); Fn<void(const RPCError &)> &&handler);
[[nodiscard]] SpecificRequestBuilder &fail(
FnMut<void(RPCError &&)> &&handler);
#else // !MTP_SENDER_USE_GENERIC_HANDLERS #else // !MTP_SENDER_USE_GENERIC_HANDLERS
template <typename Handler> template <typename Handler>
[[nodiscard]] SpecificRequestBuilder &done(Handler &&handler); [[nodiscard]] SpecificRequestBuilder &done(Handler &&handler);
@ -178,10 +174,8 @@ public:
~ConcurrentSender(); ~ConcurrentSender();
private: private:
class RPCDoneHandler; class HandlerMaker;
friend class RPCDoneHandler; friend class HandlerMaker;
class RPCFailHandler;
friend class RPCFailHandler;
friend class RequestBuilder; friend class RequestBuilder;
friend class SentRequestWrap; friend class SentRequestWrap;
@ -191,7 +185,7 @@ private:
bytes::const_span result); bytes::const_span result);
void senderRequestFail( void senderRequestFail(
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error); const RPCError &error);
void senderRequestCancel(mtpRequestId requestId); void senderRequestCancel(mtpRequestId requestId);
void senderRequestCancelAll(); void senderRequestCancelAll();
void senderRequestDetach(mtpRequestId requestId); void senderRequestDetach(mtpRequestId requestId);
@ -202,7 +196,7 @@ private:
}; };
template <typename Response, typename InvokeFullDone> template <typename Result, typename InvokeFullDone>
void ConcurrentSender::RequestBuilder::setDoneHandler( void ConcurrentSender::RequestBuilder::setDoneHandler(
InvokeFullDone &&invoke InvokeFullDone &&invoke
) noexcept { ) noexcept {
@ -211,11 +205,11 @@ void ConcurrentSender::RequestBuilder::setDoneHandler(
bytes::const_span result) mutable { bytes::const_span result) mutable {
auto from = reinterpret_cast<const mtpPrime*>(result.data()); auto from = reinterpret_cast<const mtpPrime*>(result.data());
const auto end = from + result.size() / sizeof(mtpPrime); const auto end = from + result.size() / sizeof(mtpPrime);
Response data; Result data;
if (!data.read(from, end)) { if (!data.read(from, end)) {
return false; return false;
} }
std::move(handler)(requestId, std::move(data)); handler(requestId, std::move(data));
return true; return true;
}; };
} }
@ -254,33 +248,21 @@ auto ConcurrentSender::SpecificRequestBuilder<Request>::afterDelay(
// Allow code completion to show response type. // Allow code completion to show response type.
template <typename Request> template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::done( auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
FnMut<void(Response &&)> &&handler FnMut<void(Result &&)> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setDoneHandler<Response>([handler = std::move(handler)]( setDoneHandler<Result>([handler = std::move(handler)](
mtpRequestId requestId, mtpRequestId requestId,
Response &&result) mutable { Result &&result) mutable {
std::move(handler)(std::move(result)); handler(std::move(result));
}); });
return *this; return *this;
} }
template <typename Request> template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::done( auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
FnMut<void(mtpRequestId, Response &&)> &&handler FnMut<void(mtpRequestId, Result &&)> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setDoneHandler<Response>(std::move(handler)); setDoneHandler<Result>(std::move(handler));
return *this;
}
template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
FnMut<void(mtpRequestId)> &&handler
) -> SpecificRequestBuilder & {
setDoneHandler<Response>([handler = std::move(handler)](
mtpRequestId requestId,
Response &&result) mutable {
std::move(handler)(requestId);
});
return *this; return *this;
} }
@ -288,9 +270,9 @@ template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::done( auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
FnMut<void()> &&handler FnMut<void()> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setDoneHandler<Response>([handler = std::move(handler)]( setDoneHandler<Result>([handler = std::move(handler)](
mtpRequestId requestId, mtpRequestId requestId,
Response &&result) mutable { Result &&result) mutable {
std::move(handler)(); std::move(handler)();
}); });
return *this; return *this;
@ -298,19 +280,19 @@ auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
template <typename Request> template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::fail( auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
FnMut<void(RPCError &&)> &&handler Fn<void(const RPCError &)> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setFailHandler([handler = std::move(handler)]( setFailHandler([handler = std::move(handler)](
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error) mutable { const RPCError &error) {
std::move(handler)(std::move(error)); handler(error);
}); });
return *this; return *this;
} }
template <typename Request> template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::fail( auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
FnMut<void(mtpRequestId, RPCError &&)> &&handler Fn<void(mtpRequestId, const RPCError &)> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setFailHandler(std::move(handler)); setFailHandler(std::move(handler));
return *this; return *this;
@ -318,24 +300,12 @@ auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
template <typename Request> template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::fail( auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
FnMut<void(mtpRequestId)> &&handler Fn<void()> &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
setFailHandler([handler = std::move(handler)]( setFailHandler([handler = std::move(handler)](
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error) mutable { const RPCError &error) {
std::move(handler)(requestId); handler();
});
return *this;
}
template <typename Request>
auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
FnMut<void()> &&handler
) -> SpecificRequestBuilder & {
setFailHandler([handler = std::move(handler)](
mtpRequestId requestId,
RPCError &&error) mutable {
std::move(handler)();
}); });
return *this; return *this;
} }
@ -345,38 +315,29 @@ template <typename Handler>
auto ConcurrentSender::SpecificRequestBuilder<Request>::done( auto ConcurrentSender::SpecificRequestBuilder<Request>::done(
Handler &&handler Handler &&handler
) -> SpecificRequestBuilder & { ) -> SpecificRequestBuilder & {
using Response = typename Request::ResponseType; using Result = typename Request::ResponseType;
constexpr auto takesFull = rpl::details::is_callable_plain_v< constexpr auto takesFull = rpl::details::is_callable_plain_v<
Handler, Handler,
mtpRequestId, mtpRequestId,
Response>; Result>;
constexpr auto takesResponse = rpl::details::is_callable_plain_v< constexpr auto takesResponse = rpl::details::is_callable_plain_v<
Handler, Handler,
Response>; Result>;
constexpr auto takesRequestId = rpl::details::is_callable_plain_v<
Handler,
mtpRequestId>;
constexpr auto takesNone = rpl::details::is_callable_plain_v<Handler>; constexpr auto takesNone = rpl::details::is_callable_plain_v<Handler>;
if constexpr (takesFull) { if constexpr (takesFull) {
setDoneHandler<Response>(std::forward<Handler>(handler)); setDoneHandler<Result>(std::forward<Handler>(handler));
} else if constexpr (takesResponse) { } else if constexpr (takesResponse) {
setDoneHandler<Response>([handler = std::forward<Handler>(handler)]( setDoneHandler<Result>([handler = std::forward<Handler>(handler)](
mtpRequestId requestId, mtpRequestId requestId,
Response &&result) mutable { Result &&result) mutable {
std::move(handler)(std::move(result)); handler(std::move(result));
});
} else if constexpr (takesRequestId) {
setDoneHandler<Response>([handler = std::forward<Handler>(handler)](
mtpRequestId requestId,
Response &&result) mutable {
std::move(handler)(requestId);
}); });
} else if constexpr (takesNone) { } else if constexpr (takesNone) {
setDoneHandler<Response>([handler = std::forward<Handler>(handler)]( setDoneHandler<Result>([handler = std::forward<Handler>(handler)](
mtpRequestId requestId, mtpRequestId requestId,
Response &&result) mutable { Result &&result) mutable {
std::move(handler)(); handler();
}); });
} else { } else {
static_assert(false_t(Handler{}), "Bad done handler."); static_assert(false_t(Handler{}), "Bad done handler.");
@ -396,9 +357,6 @@ auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
constexpr auto takesError = rpl::details::is_callable_plain_v< constexpr auto takesError = rpl::details::is_callable_plain_v<
Handler, Handler,
RPCError>; RPCError>;
constexpr auto takesRequestId = rpl::details::is_callable_plain_v<
Handler,
mtpRequestId>;
constexpr auto takesNone = rpl::details::is_callable_plain_v<Handler>; constexpr auto takesNone = rpl::details::is_callable_plain_v<Handler>;
if constexpr (takesFull) { if constexpr (takesFull) {
@ -406,20 +364,14 @@ auto ConcurrentSender::SpecificRequestBuilder<Request>::fail(
} else if constexpr (takesError) { } else if constexpr (takesError) {
setFailHandler([handler = std::forward<Handler>(handler)]( setFailHandler([handler = std::forward<Handler>(handler)](
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error) mutable { const RPCError &error) {
std::move(handler)(std::move(error)); handler(error);
});
} else if constexpr (takesRequestId) {
setFailHandler([handler = std::forward<Handler>(handler)](
mtpRequestId requestId,
RPCError &&error) mutable {
std::move(handler)(requestId);
}); });
} else if constexpr (takesNone) { } else if constexpr (takesNone) {
setFailHandler([handler = std::forward<Handler>(handler)]( setFailHandler([handler = std::forward<Handler>(handler)](
mtpRequestId requestId, mtpRequestId requestId,
RPCError &&error) mutable { const RPCError &error) {
std::move(handler)(); handler();
}); });
} else { } else {
static_assert(false_t(Handler{}), "Bad fail handler."); static_assert(false_t(Handler{}), "Bad fail handler.");

View file

@ -9,6 +9,20 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include <QtCore/QRegularExpression> #include <QtCore/QRegularExpression>
namespace {
[[nodiscard]] MTPrpcError ParseError(const mtpBuffer &reply) {
auto result = MTPRpcError();
auto from = reply.constData();
return result.read(from, from + reply.size())
? result
: RPCError::MTPLocal(
"RESPONSE_PARSE_FAILED",
"Error parse failed.");
}
} // namespace
RPCError::RPCError(const MTPrpcError &error) RPCError::RPCError(const MTPrpcError &error)
: _code(error.c_rpc_error().verror_code().v) { : _code(error.c_rpc_error().verror_code().v) {
QString text = qs(error.c_rpc_error().verror_message()); QString text = qs(error.c_rpc_error().verror_message());
@ -30,3 +44,37 @@ RPCError::RPCError(const MTPrpcError &error)
} }
} }
} }
RPCError::RPCError(const mtpBuffer &reply) : RPCError(ParseError(reply)) {
}
int32 RPCError::code() const {
return _code;
}
const QString &RPCError::type() const {
return _type;
}
const QString &RPCError::description() const {
return _description;
}
MTPrpcError RPCError::MTPLocal(
const QString &type,
const QString &description) {
return MTP_rpc_error(
MTP_int(0),
MTP_bytes(
("CLIENT_"
+ type
+ (description.length()
? (": " + description)
: QString())).toUtf8()));
}
RPCError RPCError::Local(
const QString &type,
const QString &description) {
return RPCError(MTPLocal(type, description));
}

View file

@ -11,38 +11,27 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
class RPCError { class RPCError {
public: public:
RPCError(const MTPrpcError &error); explicit RPCError(const MTPrpcError &error);
explicit RPCError(const mtpBuffer &reply);
int32 code() const {
return _code;
}
const QString &type() const {
return _type;
}
const QString &description() const {
return _description;
}
enum { enum {
NoError, NoError,
TimeoutError TimeoutError
}; };
static RPCError Local(const QString &type, const QString &description) { [[nodiscard]] int32 code() const;
return MTP_rpc_error( [[nodiscard]] const QString &type() const;
MTP_int(0), [[nodiscard]] const QString &description() const;
MTP_bytes(
("CLIENT_" [[nodiscard]] static RPCError Local(
+ type const QString &type,
+ (description.length() const QString &description);
? (": " + description) [[nodiscard]] static MTPrpcError MTPLocal(
: QString())).toUtf8())); const QString &type,
} const QString &description);
private: private:
int32 _code; int32 _code = 0;
QString _type, _description; QString _type, _description;
}; };
@ -61,515 +50,18 @@ inline bool isDefaultHandledError(const RPCError &error) {
return isTemporaryError(error); return isTemporaryError(error);
} }
} // namespace MTP struct Response {
mtpBuffer reply;
class RPCAbstractDoneHandler { // abstract done mtpMsgId outerMsgId = 0;
public:
[[nodiscard]] virtual bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) = 0;
virtual ~RPCAbstractDoneHandler() {
}
};
using RPCDoneHandlerPtr = std::shared_ptr<RPCAbstractDoneHandler>;
class RPCAbstractFailHandler { // abstract fail
public:
virtual bool operator()(mtpRequestId requestId, const RPCError &e) = 0;
virtual ~RPCAbstractFailHandler() {
}
};
using RPCFailHandlerPtr = std::shared_ptr<RPCAbstractFailHandler>;
struct RPCResponseHandler {
RPCResponseHandler() = default;
RPCResponseHandler(RPCDoneHandlerPtr &&done, RPCFailHandlerPtr &&fail)
: onDone(std::move(done))
, onFail(std::move(fail)) {
}
RPCDoneHandlerPtr onDone;
RPCFailHandlerPtr onFail;
};
class RPCDoneHandlerBare : public RPCAbstractDoneHandler { // done(from, end)
using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *);
public:
RPCDoneHandlerBare(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
return (*_onDone)(from, end);
}
private:
CallbackType _onDone;
};
class RPCDoneHandlerBareReq : public RPCAbstractDoneHandler { // done(from, end, req_id)
using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *, mtpRequestId);
public:
RPCDoneHandlerBareReq(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
return (*_onDone)(from, end, requestId);
}
private:
CallbackType _onDone;
};
template <typename TReturn, typename TResponse>
class RPCDoneHandlerPlain : public RPCAbstractDoneHandler { // done(result)
using CallbackType = TReturn (*)(const TResponse &);
public:
RPCDoneHandlerPlain(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto response = TResponse();
if (!response.read(from, end)) {
return false;
}
(*_onDone)(std::move(response));
return true;
}
private:
CallbackType _onDone;
};
template <typename TReturn, typename TResponse>
class RPCDoneHandlerReq : public RPCAbstractDoneHandler { // done(result, req_id)
using CallbackType = TReturn (*)(const TResponse &, mtpRequestId);
public:
RPCDoneHandlerReq(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto response = TResponse();
if (!response.read(from, end)) {
return false;
}
(*_onDone)(std::move(response), requestId);
return true;
}
private:
CallbackType _onDone;
};
template <typename TReturn>
class RPCDoneHandlerNo : public RPCAbstractDoneHandler { // done()
using CallbackType = TReturn (*)();
public:
RPCDoneHandlerNo(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
(*_onDone)();
return true;
}
private:
CallbackType _onDone;
};
template <typename TReturn>
class RPCDoneHandlerNoReq : public RPCAbstractDoneHandler { // done(req_id)
using CallbackType = TReturn (*)(mtpRequestId);
public:
RPCDoneHandlerNoReq(CallbackType onDone) : _onDone(onDone) {
}
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
(*_onDone)(requestId);
return true;
}
private:
CallbackType _onDone;
};
class RPCFailHandlerPlain : public RPCAbstractFailHandler { // fail(error)
using CallbackType = bool (*)(const RPCError &);
public:
RPCFailHandlerPlain(CallbackType onFail) : _onFail(onFail) {
}
bool operator()(mtpRequestId requestId, const RPCError &e) override {
return (*_onFail)(e);
}
private:
CallbackType _onFail;
};
class RPCFailHandlerReq : public RPCAbstractFailHandler { // fail(error, req_id)
using CallbackType = bool (*)(const RPCError &, mtpRequestId);
public:
RPCFailHandlerReq(CallbackType onFail) : _onFail(onFail) {
}
bool operator()(mtpRequestId requestId, const RPCError &e) override {
return (*_onFail)(e, requestId);
}
private:
CallbackType _onFail;
};
class RPCFailHandlerNo : public RPCAbstractFailHandler { // fail()
using CallbackType = bool (*)();
public:
RPCFailHandlerNo(CallbackType onFail) : _onFail(onFail) {
}
bool operator()(mtpRequestId requestId, const RPCError &e) override {
return (*_onFail)();
}
private:
CallbackType _onFail;
};
class RPCFailHandlerNoReq : public RPCAbstractFailHandler { // fail(req_id)
using CallbackType = bool (*)(mtpRequestId);
public:
RPCFailHandlerNoReq(CallbackType onFail) : _onFail(onFail) {
}
bool operator()(mtpRequestId requestId, const RPCError &e) override {
return (*_onFail)(requestId);
}
private:
CallbackType _onFail;
};
struct RPCCallbackClear {
RPCCallbackClear(mtpRequestId id, int32 code = RPCError::NoError)
: requestId(id)
, errorCode(code) {
}
mtpRequestId requestId = 0; mtpRequestId requestId = 0;
int32 errorCode = 0;
}; };
inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *)) { // done(from, end) using DoneHandler = FnMut<bool(const Response&)>;
return RPCDoneHandlerPtr(new RPCDoneHandlerBare(onDone)); using FailHandler = Fn<bool(const RPCError&, const Response&)>;
}
inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *, mtpRequestId)) { // done(from, end, req_id)
return RPCDoneHandlerPtr(new RPCDoneHandlerBareReq(onDone));
}
template <typename TReturn, typename TResponse>
inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const TResponse &)) { // done(result)
return RPCDoneHandlerPtr(new RPCDoneHandlerPlain<TReturn, TResponse>(onDone));
}
template <typename TReturn, typename TResponse>
inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const TResponse &, mtpRequestId)) { // done(result, req_id)
return RPCDoneHandlerPtr(new RPCDoneHandlerReq<TReturn, TResponse>(onDone));
}
template <typename TReturn>
inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)()) { // done()
return RPCDoneHandlerPtr(new RPCDoneHandlerNo<TReturn>(onDone));
}
template <typename TReturn>
inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(mtpRequestId)) { // done(req_id)
return RPCDoneHandlerPtr(new RPCDoneHandlerNoReq<TReturn>(onDone));
}
inline RPCFailHandlerPtr rpcFail(bool (*onFail)(const RPCError &)) { // fail(error)
return RPCFailHandlerPtr(new RPCFailHandlerPlain(onFail));
}
inline RPCFailHandlerPtr rpcFail(bool (*onFail)(const RPCError &, mtpRequestId)) { // fail(error, req_id)
return RPCFailHandlerPtr(new RPCFailHandlerReq(onFail));
}
inline RPCFailHandlerPtr rpcFail(bool (*onFail)()) { // fail()
return RPCFailHandlerPtr(new RPCFailHandlerNo(onFail));
}
inline RPCFailHandlerPtr rpcFail(bool (*onFail)(mtpRequestId)) { // fail(req_id)
return RPCFailHandlerPtr(new RPCFailHandlerNoReq(onFail));
}
using MTPStateChangedHandler = void (*)(int32 dcId, int32 state);
using MTPSessionResetHandler = void (*)(int32 dcId);
template <typename Base, typename FunctionType>
class RPCHandlerImplementation : public Base {
protected:
using Lambda = FnMut<FunctionType>;
using Parent = RPCHandlerImplementation<Base, FunctionType>;
public:
RPCHandlerImplementation(Lambda handler) : _handler(std::move(handler)) {
}
protected:
Lambda _handler;
struct ResponseHandler {
DoneHandler done;
FailHandler fail;
}; };
template <typename FunctionType> } // namespace MTP
using RPCDoneHandlerImplementation = RPCHandlerImplementation<RPCAbstractDoneHandler, FunctionType>;
class RPCDoneHandlerImplementationBare : public RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*)> { // done(from, end)
public:
using RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*)>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
return this->_handler ? this->_handler(from, end) : true;
}
};
class RPCDoneHandlerImplementationBareReq : public RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*, mtpRequestId)> { // done(from, end, req_id)
public:
using RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*, mtpRequestId)>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
return this->_handler ? this->_handler(from, end, requestId) : true;
}
};
template <typename R, typename TResponse>
class RPCDoneHandlerImplementationPlain : public RPCDoneHandlerImplementation<R(const TResponse&)> { // done(result)
public:
using RPCDoneHandlerImplementation<R(const TResponse&)>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto response = TResponse();
if (!response.read(from, end)) {
return false;
}
if (this->_handler) {
this->_handler(std::move(response));
}
return true;
}
};
template <typename R, typename TResponse>
class RPCDoneHandlerImplementationReq : public RPCDoneHandlerImplementation<R(const TResponse&, mtpRequestId)> { // done(result, req_id)
public:
using RPCDoneHandlerImplementation<R(const TResponse&, mtpRequestId)>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto response = TResponse();
if (!response.read(from, end)) {
return false;
}
if (this->_handler) {
this->_handler(std::move(response), requestId);
}
return true;
}
};
template <typename R>
class RPCDoneHandlerImplementationNo : public RPCDoneHandlerImplementation<R()> { // done()
public:
using RPCDoneHandlerImplementation<R()>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
if (this->_handler) {
this->_handler();
}
return true;
}
};
template <typename R>
class RPCDoneHandlerImplementationNoReq : public RPCDoneHandlerImplementation<R(mtpRequestId)> { // done(req_id)
public:
using RPCDoneHandlerImplementation<R(mtpRequestId)>::Parent::Parent;
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
if (this->_handler) {
this->_handler(requestId);
}
return true;
}
};
template <typename Lambda>
constexpr bool rpcDone_canCallBare_v = rpl::details::is_callable_plain_v<
Lambda, const mtpPrime*, const mtpPrime*>;
template <typename Lambda>
constexpr bool rpcDone_canCallBareReq_v = rpl::details::is_callable_plain_v<
Lambda, const mtpPrime*, const mtpPrime*, mtpRequestId>;
template <typename Lambda>
constexpr bool rpcDone_canCallNo_v = rpl::details::is_callable_plain_v<
Lambda>;
template <typename Lambda>
constexpr bool rpcDone_canCallNoReq_v = rpl::details::is_callable_plain_v<
Lambda, mtpRequestId>;
template <typename Function>
struct rpcDone_canCallPlain : std::false_type {
};
template <typename Lambda, typename Return, typename T>
struct rpcDone_canCallPlain<Return(Lambda::*)(const T&)> : std::true_type {
using Arg = T;
};
template <typename Lambda, typename Return, typename T>
struct rpcDone_canCallPlain<Return(Lambda::*)(const T&)const>
: rpcDone_canCallPlain<Return(Lambda::*)(const T&)> {
};
template <typename Function>
constexpr bool rpcDone_canCallPlain_v = rpcDone_canCallPlain<Function>::value;
template <typename Function>
struct rpcDone_canCallReq : std::false_type {
};
template <typename Lambda, typename Return, typename T>
struct rpcDone_canCallReq<Return(Lambda::*)(const T&, mtpRequestId)> : std::true_type {
using Arg = T;
};
template <typename Lambda, typename Return, typename T>
struct rpcDone_canCallReq<Return(Lambda::*)(const T&, mtpRequestId)const>
: rpcDone_canCallReq<Return(Lambda::*)(const T&, mtpRequestId)> {
};
template <typename Function>
constexpr bool rpcDone_canCallReq_v = rpcDone_canCallReq<Function>::value;
template <typename Function>
struct rpcDone_returnType;
template <typename Lambda, typename Return, typename ...Args>
struct rpcDone_returnType<Return(Lambda::*)(Args...)> {
using type = Return;
};
template <typename Lambda, typename Return, typename ...Args>
struct rpcDone_returnType<Return(Lambda::*)(Args...)const> {
using type = Return;
};
template <typename Function>
using rpcDone_returnType_t = typename rpcDone_returnType<Function>::type;
template <
typename Lambda,
typename Function = crl::deduced_call_type<Lambda>>
RPCDoneHandlerPtr rpcDone(Lambda lambda) {
using R = rpcDone_returnType_t<Function>;
if constexpr (rpcDone_canCallBare_v<Lambda>) {
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBare(std::move(lambda)));
} else if constexpr (rpcDone_canCallBareReq_v<Lambda>) {
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBareReq(std::move(lambda)));
} else if constexpr (rpcDone_canCallNo_v<Lambda>) {
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationNo<R>(std::move(lambda)));
} else if constexpr (rpcDone_canCallNoReq_v<Lambda>) {
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationNoReq<R>(std::move(lambda)));
} else if constexpr (rpcDone_canCallPlain_v<Function>) {
using T = typename rpcDone_canCallPlain<Function>::Arg;
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationPlain<R, T>(std::move(lambda)));
} else if constexpr (rpcDone_canCallReq_v<Function>) {
using T = typename rpcDone_canCallReq<Function>::Arg;
return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationReq<R, T>(std::move(lambda)));
} else {
static_assert(false_t(lambda), "Unknown method.");
}
}
template <typename FunctionType>
using RPCFailHandlerImplementation = RPCHandlerImplementation<RPCAbstractFailHandler, FunctionType>;
class RPCFailHandlerImplementationPlain : public RPCFailHandlerImplementation<bool(const RPCError&)> { // fail(error)
public:
using Parent::Parent;
bool operator()(mtpRequestId requestId, const RPCError &error) override {
return _handler ? _handler(error) : true;
}
};
class RPCFailHandlerImplementationReq : public RPCFailHandlerImplementation<bool(const RPCError&, mtpRequestId)> { // fail(error, req_id)
public:
using Parent::Parent;
bool operator()(mtpRequestId requestId, const RPCError &error) override {
return this->_handler ? this->_handler(error, requestId) : true;
}
};
class RPCFailHandlerImplementationNo : public RPCFailHandlerImplementation<bool()> { // fail()
public:
using Parent::Parent;
bool operator()(mtpRequestId requestId, const RPCError &error) override {
return this->_handler ? this->_handler() : true;
}
};
class RPCFailHandlerImplementationNoReq : public RPCFailHandlerImplementation<bool(mtpRequestId)> { // fail(req_id)
public:
using Parent::Parent;
bool operator()(mtpRequestId requestId, const RPCError &error) override {
return this->_handler ? this->_handler(requestId) : true;
}
};
template <typename Lambda>
constexpr bool rpcFail_canCallNo_v = rpl::details::is_callable_plain_v<
Lambda>;
template <typename Lambda>
constexpr bool rpcFail_canCallNoReq_v = rpl::details::is_callable_plain_v<
Lambda, mtpRequestId>;
template <typename Lambda>
constexpr bool rpcFail_canCallPlain_v = rpl::details::is_callable_plain_v<
Lambda, const RPCError&>;
template <typename Lambda>
constexpr bool rpcFail_canCallReq_v = rpl::details::is_callable_plain_v<
Lambda, const RPCError&, mtpRequestId>;
template <
typename Lambda,
typename Function = crl::deduced_call_type<Lambda>>
RPCFailHandlerPtr rpcFail(Lambda lambda) {
if constexpr (rpcFail_canCallNo_v<Lambda>) {
return RPCFailHandlerPtr(new RPCFailHandlerImplementationNo(std::move(lambda)));
} else if constexpr (rpcFail_canCallNoReq_v<Lambda>) {
return RPCFailHandlerPtr(new RPCFailHandlerImplementationNoReq(std::move(lambda)));
} else if constexpr (rpcFail_canCallPlain_v<Lambda>) {
return RPCFailHandlerPtr(new RPCFailHandlerImplementationPlain(std::move(lambda)));
} else if constexpr (rpcFail_canCallReq_v<Lambda>) {
return RPCFailHandlerPtr(new RPCFailHandlerImplementationReq(std::move(lambda)));
} else {
static_assert(false_t(lambda), "Unknown method.");
}
}

View file

@ -22,111 +22,108 @@ class Sender {
RequestBuilder &operator=(RequestBuilder &&other) = delete; RequestBuilder &operator=(RequestBuilder &&other) = delete;
protected: protected:
using FailPlainHandler = FnMut<void(const RPCError &error)>;
using FailRequestIdHandler = FnMut<void(const RPCError &error, mtpRequestId requestId)>;
enum class FailSkipPolicy { enum class FailSkipPolicy {
Simple, Simple,
HandleFlood, HandleFlood,
HandleAll, HandleAll,
}; };
template <typename Response> using FailPlainHandler = Fn<void()>;
struct DonePlainPolicy { using FailErrorHandler = Fn<void(const RPCError&)>;
using Callback = FnMut<void(const Response &result)>; using FailRequestIdHandler = Fn<void(const RPCError&, mtpRequestId)>;
static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) { using FailFullHandler = Fn<void(const RPCError&, const Response&)>;
handler(result);
}
}; template <typename ...Args>
template <typename Response> static constexpr bool IsCallable
struct DoneRequestIdPolicy { = rpl::details::is_callable_plain_v<Args...>;
using Callback = FnMut<void(const Response &result, mtpRequestId requestId)>;
static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) {
handler(result, requestId);
}
}; template <typename Result, typename Handler>
template <typename Response, template <typename> typename PolicyTemplate> [[nodiscard]] DoneHandler MakeDoneHandler(
class DoneHandler : public RPCAbstractDoneHandler { not_null<Sender*> sender,
using Policy = PolicyTemplate<Response>; Handler &&handler) {
using Callback = typename Policy::Callback; return [sender, handler = std::forward<Handler>(handler)](
const Response &response) mutable {
auto onstack = std::move(handler);
sender->senderRequestHandled(response.requestId);
public: auto result = Result();
DoneHandler(not_null<Sender*> sender, Callback handler) : _sender(sender), _handler(std::move(handler)) { auto from = response.reply.constData();
} if (!result.read(from, from + response.reply.size())) {
bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto handler = std::move(_handler);
_sender->senderRequestHandled(requestId);
auto result = Response();
if (!result.read(from, end)) {
return false; return false;
} } else if (!onstack) {
if (handler) { return true;
Policy::handle(std::move(handler), requestId, std::move(result)); } else if constexpr (IsCallable<
Handler,
const Result&,
const Response&>) {
onstack(result, response);
} else if constexpr (IsCallable<
Handler,
const Result&,
mtpRequestId>) {
onstack(result, response.requestId);
} else if constexpr (IsCallable<
Handler,
const Result&>) {
onstack(result);
} else if constexpr (IsCallable<Handler>) {
onstack();
} else {
static_assert(false_t(Handler{}), "Bad done handler.");
} }
return true; return true;
} };
}
private: template <typename Handler>
not_null<Sender*> _sender; [[nodiscard]] FailHandler MakeFailHandler(
Callback _handler; not_null<Sender*> sender,
Handler &&handler,
}; FailSkipPolicy skipPolicy) {
return [
struct FailPlainPolicy { sender,
using Callback = FnMut<void(const RPCError &error)>; handler = std::forward<Handler>(handler),
static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) { skipPolicy
handler(error); ](const RPCError &error, const Response &response) {
} if (skipPolicy == FailSkipPolicy::Simple) {
};
struct FailRequestIdPolicy {
using Callback = FnMut<void(const RPCError &error, mtpRequestId requestId)>;
static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) {
handler(error, requestId);
}
};
template <typename Policy>
class FailHandler : public RPCAbstractFailHandler {
using Callback = typename Policy::Callback;
public:
FailHandler(not_null<Sender*> sender, Callback handler, FailSkipPolicy skipPolicy)
: _sender(sender)
, _handler(std::move(handler))
, _skipPolicy(skipPolicy) {
}
bool operator()(mtpRequestId requestId, const RPCError &error) override {
if (_skipPolicy == FailSkipPolicy::Simple) {
if (isDefaultHandledError(error)) { if (isDefaultHandledError(error)) {
return false; return false;
} }
} else if (_skipPolicy == FailSkipPolicy::HandleFlood) { } else if (skipPolicy == FailSkipPolicy::HandleFlood) {
if (isDefaultHandledError(error) && !isFloodError(error)) { if (isDefaultHandledError(error) && !isFloodError(error)) {
return false; return false;
} }
} }
auto handler = std::move(_handler); auto onstack = handler;
_sender->senderRequestHandled(requestId); sender->senderRequestHandled(response.requestId);
if (handler) { if (!onstack) {
Policy::handle(std::move(handler), requestId, error); return true;
} else if constexpr (IsCallable<
Handler,
const RPCError&,
const Response&>) {
onstack(error, response);
} else if constexpr (IsCallable<
Handler,
const RPCError&,
mtpRequestId>) {
onstack(error, response.requestId);
} else if constexpr (IsCallable<
Handler,
const RPCError&>) {
onstack(error);
} else if constexpr (IsCallable<Handler>) {
onstack();
} else {
static_assert(false_t(Handler{}), "Bad fail handler.");
} }
return true; return true;
} };
}
private: explicit RequestBuilder(not_null<Sender*> sender) noexcept
not_null<Sender*> _sender; : _sender(sender) {
Callback _handler;
FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;
};
explicit RequestBuilder(not_null<Sender*> sender) noexcept : _sender(sender) {
} }
RequestBuilder(RequestBuilder &&other) = default; RequestBuilder(RequestBuilder &&other) = default;
@ -136,14 +133,12 @@ class Sender {
void setCanWait(crl::time ms) noexcept { void setCanWait(crl::time ms) noexcept {
_canWait = ms; _canWait = ms;
} }
void setDoneHandler(RPCDoneHandlerPtr &&handler) noexcept { void setDoneHandler(DoneHandler &&handler) noexcept {
_done = std::move(handler); _done = std::move(handler);
} }
void setFailHandler(FailPlainHandler &&handler) noexcept { template <typename Handler>
_fail = std::move(handler); void setFailHandler(Handler &&handler) noexcept {
} _fail = std::forward<Handler>(handler);
void setFailHandler(FailRequestIdHandler &&handler) noexcept {
_fail = std::move(handler);
} }
void setFailSkipPolicy(FailSkipPolicy policy) noexcept { void setFailSkipPolicy(FailSkipPolicy policy) noexcept {
_failSkipPolicy = policy; _failSkipPolicy = policy;
@ -158,18 +153,12 @@ class Sender {
crl::time takeCanWait() const noexcept { crl::time takeCanWait() const noexcept {
return _canWait; return _canWait;
} }
RPCDoneHandlerPtr takeOnDone() noexcept { DoneHandler takeOnDone() noexcept {
return std::move(_done); return std::move(_done);
} }
RPCFailHandlerPtr takeOnFail() { FailHandler takeOnFail() {
return v::match(_fail, [&](FailPlainHandler &value) return v::match(_fail, [&](auto &value) {
-> RPCFailHandlerPtr { return MakeFailHandler(
return std::make_shared<FailHandler<FailPlainPolicy>>(
_sender,
std::move(value),
_failSkipPolicy);
}, [&](FailRequestIdHandler &value) -> RPCFailHandlerPtr {
return std::make_shared<FailHandler<FailRequestIdPolicy>>(
_sender, _sender,
std::move(value), std::move(value),
_failSkipPolicy); _failSkipPolicy);
@ -190,8 +179,12 @@ class Sender {
not_null<Sender*> _sender; not_null<Sender*> _sender;
ShiftedDcId _dcId = 0; ShiftedDcId _dcId = 0;
crl::time _canWait = 0; crl::time _canWait = 0;
RPCDoneHandlerPtr _done; DoneHandler _done;
std::variant<FailPlainHandler, FailRequestIdHandler> _fail; std::variant<
FailPlainHandler,
FailErrorHandler,
FailRequestIdHandler,
FailFullHandler> _fail;
FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple; FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple;
mtpRequestId _afterRequestId = 0; mtpRequestId _afterRequestId = 0;
@ -210,7 +203,9 @@ public:
class SpecificRequestBuilder : public RequestBuilder { class SpecificRequestBuilder : public RequestBuilder {
private: private:
friend class Sender; friend class Sender;
SpecificRequestBuilder(not_null<Sender*> sender, Request &&request) noexcept : RequestBuilder(sender), _request(std::move(request)) { SpecificRequestBuilder(not_null<Sender*> sender, Request &&request) noexcept
: RequestBuilder(sender)
, _request(std::move(request)) {
} }
SpecificRequestBuilder(SpecificRequestBuilder &&other) = default; SpecificRequestBuilder(SpecificRequestBuilder &&other) = default;
@ -223,22 +218,63 @@ public:
setCanWait(ms); setCanWait(ms);
return *this; return *this;
} }
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void(const typename Request::ResponseType &result)> callback) {
setDoneHandler(std::make_shared<DoneHandler<typename Request::ResponseType, DonePlainPolicy>>(sender(), std::move(callback))); using Result = typename Request::ResponseType;
[[nodiscard]] SpecificRequestBuilder &done(
FnMut<void(
const Result &result,
mtpRequestId requestId)> callback) {
setDoneHandler(
MakeDoneHandler<Result>(sender(), std::move(callback)));
return *this; return *this;
} }
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void(const typename Request::ResponseType &result, mtpRequestId requestId)> callback) { [[nodiscard]] SpecificRequestBuilder &done(
setDoneHandler(std::make_shared<DoneHandler<typename Request::ResponseType, DoneRequestIdPolicy>>(sender(), std::move(callback))); FnMut<void(
const Result &result,
const Response &response)> callback) {
setDoneHandler(
MakeDoneHandler<Result>(sender(), std::move(callback)));
return *this; return *this;
} }
[[nodiscard]] SpecificRequestBuilder &fail(FnMut<void(const RPCError &error)> callback) noexcept { [[nodiscard]] SpecificRequestBuilder &done(
FnMut<void()> callback) {
setDoneHandler(
MakeDoneHandler<Result>(sender(), std::move(callback)));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &done(
FnMut<void(
const typename Request::ResponseType &result)> callback) {
setDoneHandler(
MakeDoneHandler<Result>(sender(), std::move(callback)));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &fail(
Fn<void(
const RPCError &error,
mtpRequestId requestId)> callback) noexcept {
setFailHandler(std::move(callback)); setFailHandler(std::move(callback));
return *this; return *this;
} }
[[nodiscard]] SpecificRequestBuilder &fail(FnMut<void(const RPCError &error, mtpRequestId requestId)> callback) noexcept { [[nodiscard]] SpecificRequestBuilder &fail(
Fn<void(
const RPCError &error,
const Response &response)> callback) noexcept {
setFailHandler(std::move(callback)); setFailHandler(std::move(callback));
return *this; return *this;
} }
[[nodiscard]] SpecificRequestBuilder &fail(
Fn<void()> callback) noexcept {
setFailHandler(std::move(callback));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &fail(
Fn<void(const RPCError &error)> callback) noexcept {
setFailHandler(std::move(callback));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept { [[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept {
setFailSkipPolicy(FailSkipPolicy::HandleFlood); setFailSkipPolicy(FailSkipPolicy::HandleFlood);
return *this; return *this;

View file

@ -225,13 +225,6 @@ void Session::start() {
_shiftedDcId); _shiftedDcId);
} }
bool Session::rpcErrorOccured(
mtpRequestId requestId,
const RPCFailHandlerPtr &onFail,
const RPCError &error) { // return true if need to clean request data
return _instance->rpcErrorOccured(requestId, onFail, error);
}
void Session::restart() { void Session::restart() {
if (_killed) { if (_killed) {
DEBUG_LOG(("Session Error: can't restart a killed session")); DEBUG_LOG(("Session Error: can't restart a killed session"));
@ -560,25 +553,17 @@ void Session::tryToReceive() {
} }
while (true) { while (true) {
auto lock = QWriteLocker(_data->haveReceivedMutex()); auto lock = QWriteLocker(_data->haveReceivedMutex());
const auto responses = base::take(_data->haveReceivedResponses()); const auto messages = base::take(_data->haveReceivedMessages());
const auto updates = base::take(_data->haveReceivedUpdates());
lock.unlock(); lock.unlock();
if (responses.empty() && updates.empty()) { if (messages.empty()) {
break; break;
} }
for (const auto &[requestId, response] : responses) { for (const auto &message : messages) {
_instance->execCallback( if (message.requestId) {
requestId, _instance->processCallback(message);
response.constData(), } else if (_shiftedDcId == BareDcId(_shiftedDcId)) {
response.constData() + response.size()); // Process updates only in main session.
} _instance->processUpdate(message);
// Call globalCallback only in main session.
if (_shiftedDcId == BareDcId(_shiftedDcId)) {
for (const auto &update : updates) {
_instance->globalCallback(
update.constData(),
update.constData() + update.size());
} }
} }
} }

View file

@ -84,11 +84,8 @@ public:
base::flat_map<mtpMsgId, SerializedRequest> &haveSentMap() { base::flat_map<mtpMsgId, SerializedRequest> &haveSentMap() {
return _haveSent; return _haveSent;
} }
base::flat_map<mtpRequestId, mtpBuffer> &haveReceivedResponses() { std::vector<Response> &haveReceivedMessages() {
return _receivedResponses; return _receivedMessages;
}
std::vector<mtpBuffer> &haveReceivedUpdates() {
return _receivedUpdates;
} }
// SessionPrivate -> Session interface. // SessionPrivate -> Session interface.
@ -128,8 +125,7 @@ private:
base::flat_map<mtpMsgId, SerializedRequest> _haveSent; // map of msg_id -> request, that was sent base::flat_map<mtpMsgId, SerializedRequest> _haveSent; // map of msg_id -> request, that was sent
QReadWriteLock _haveSentLock; QReadWriteLock _haveSentLock;
base::flat_map<mtpRequestId, mtpBuffer> _receivedResponses; // map of request_id -> response that should be processed in the main thread std::vector<Response> _receivedMessages; // list of responses / updates that should be processed in the main thread
std::vector<mtpBuffer> _receivedUpdates; // list of updates that should be processed in the main thread
QReadWriteLock _haveReceivedLock; QReadWriteLock _haveReceivedLock;
}; };
@ -192,11 +188,6 @@ private:
void killConnection(); void killConnection();
bool rpcErrorOccured(
mtpRequestId requestId,
const RPCFailHandlerPtr &onFail,
const RPCError &err);
[[nodiscard]] bool releaseGenericKeyCreationOnDone( [[nodiscard]] bool releaseGenericKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKeyUsedForBind); const AuthKeyPtr &persistentKeyUsedForBind);

View file

@ -1363,7 +1363,12 @@ void SessionPrivate::handleReceived() {
).arg(_encryptionKey->keyId())); ).arg(_encryptionKey->keyId()));
if (_receivedMessageIds.registerMsgId(msgId, needAck)) { if (_receivedMessageIds.registerMsgId(msgId, needAck)) {
res = handleOneReceived(from, end, msgId, serverTime, serverSalt, badTime); res = handleOneReceived(from, end, msgId, {
.outerMsgId = msgId,
.serverSalt = serverSalt,
.serverTime = serverTime,
.badTime = badTime,
});
} }
_receivedMessageIds.shrink(); _receivedMessageIds.shrink();
@ -1374,12 +1379,11 @@ void SessionPrivate::handleReceived() {
} }
auto lock = QReadLocker(_sessionData->haveReceivedMutex()); auto lock = QReadLocker(_sessionData->haveReceivedMutex());
const auto tryToReceive = !_sessionData->haveReceivedResponses().empty() const auto tryToReceive = !_sessionData->haveReceivedMessages().empty();
|| !_sessionData->haveReceivedUpdates().empty();
lock.unlock(); lock.unlock();
if (tryToReceive) { if (tryToReceive) {
DEBUG_LOG(("MTP Info: queueTryToReceive() - need to parse in another thread, %1 responses, %2 updates.").arg(_sessionData->haveReceivedResponses().size()).arg(_sessionData->haveReceivedUpdates().size())); DEBUG_LOG(("MTP Info: queueTryToReceive() - need to parse in another thread, %1 messages.").arg(_sessionData->haveReceivedMessages().size()));
_sessionData->queueTryToReceive(); _sessionData->queueTryToReceive();
} }
@ -1410,9 +1414,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
const mtpPrime *from, const mtpPrime *from,
const mtpPrime *end, const mtpPrime *end,
uint64 msgId, uint64 msgId,
int32 serverTime, OuterInfo info) {
uint64 serverSalt,
bool badTime) {
Expects(from < end); Expects(from < end);
switch (mtpTypeId(*from)) { switch (mtpTypeId(*from)) {
@ -1423,7 +1425,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
if (response.empty()) { if (response.empty()) {
return HandleResult::RestartConnection; return HandleResult::RestartConnection;
} }
return handleOneReceived(response.data(), response.data() + response.size(), msgId, serverTime, serverSalt, badTime); return handleOneReceived(response.data(), response.data() + response.size(), msgId, info);
} }
case mtpc_msg_container: { case mtpc_msg_container: {
@ -1475,8 +1477,8 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
auto res = HandleResult::Success; // if no need to handle, then succeed auto res = HandleResult::Success; // if no need to handle, then succeed
if (_receivedMessageIds.registerMsgId(inMsgId.v, needAck)) { if (_receivedMessageIds.registerMsgId(inMsgId.v, needAck)) {
res = handleOneReceived(from, otherEnd, inMsgId.v, serverTime, serverSalt, badTime); res = handleOneReceived(from, otherEnd, inMsgId.v, info);
badTime = false; info.badTime = false;
} }
if (res != HandleResult::Success) { if (res != HandleResult::Success) {
return res; return res;
@ -1495,15 +1497,15 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
DEBUG_LOG(("Message Info: acks received, ids: %1" DEBUG_LOG(("Message Info: acks received, ids: %1"
).arg(LogIdsVector(ids))); ).arg(LogIdsVector(ids)));
if (ids.isEmpty()) { if (ids.isEmpty()) {
return badTime ? HandleResult::Ignored : HandleResult::Success; return info.badTime ? HandleResult::Ignored : HandleResult::Success;
} }
if (badTime) { if (info.badTime) {
if (!requestsFixTimeSalt(ids, serverTime, serverSalt)) { if (!requestsFixTimeSalt(ids, info)) {
return HandleResult::Ignored; return HandleResult::Ignored;
} }
} else { } else {
correctUnixtimeByFastRequest(ids, serverTime); correctUnixtimeByFastRequest(ids, info.serverTime);
} }
requestsAcked(ids); requestsAcked(ids);
} return HandleResult::Success; } return HandleResult::Success;
@ -1546,28 +1548,28 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
if (!wasSent(resendId)) { if (!wasSent(resendId)) {
DEBUG_LOG(("Message Error: " DEBUG_LOG(("Message Error: "
"such message was not sent recently %1").arg(resendId)); "such message was not sent recently %1").arg(resendId));
return badTime return info.badTime
? HandleResult::Ignored ? HandleResult::Ignored
: HandleResult::Success; : HandleResult::Success;
} }
if (needResend) { // bad msg_id or bad container if (needResend) { // bad msg_id or bad container
if (serverSalt) { if (info.serverSalt) {
_sessionSalt = serverSalt; _sessionSalt = info.serverSalt;
} }
correctUnixtimeWithBadLocal(serverTime); correctUnixtimeWithBadLocal(info.serverTime);
DEBUG_LOG(("Message Info: unixtime updated, now %1, resending in container...").arg(serverTime)); DEBUG_LOG(("Message Info: unixtime updated, now %1, resending in container...").arg(info.serverTime));
resend(resendId, 0, true); resend(resendId, 0, true);
} else { // must create new session, because msg_id and msg_seqno are inconsistent } else { // must create new session, because msg_id and msg_seqno are inconsistent
if (badTime) { if (info.badTime) {
if (serverSalt) { if (info.serverSalt) {
_sessionSalt = serverSalt; _sessionSalt = info.serverSalt;
} }
correctUnixtimeWithBadLocal(serverTime); correctUnixtimeWithBadLocal(info.serverTime);
badTime = false; info.badTime = false;
} }
LOG(("Message Info: bad message notification received, msgId %1, error_code %2").arg(data.vbad_msg_id().v).arg(errorCode)); LOG(("Message Info: bad message notification received, msgId %1, error_code %2").arg(data.vbad_msg_id().v).arg(errorCode));
return HandleResult::ResetSession; return HandleResult::ResetSession;
@ -1582,20 +1584,24 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
).arg(badMsgId ).arg(badMsgId
).arg(errorCode ).arg(errorCode
).arg(requestId)); ).arg(requestId));
auto response = mtpBuffer(); auto reply = mtpBuffer();
MTPRpcError(MTP_rpc_error( MTPRpcError(MTP_rpc_error(
MTP_int(500), MTP_int(500),
MTP_string("PROTOCOL_ERROR") MTP_string("PROTOCOL_ERROR")
)).write(response); )).write(reply);
// Save rpc_error for processing in the main thread. // Save rpc_error for processing in the main thread.
QWriteLocker locker(_sessionData->haveReceivedMutex()); QWriteLocker locker(_sessionData->haveReceivedMutex());
_sessionData->haveReceivedResponses().emplace(requestId, response); _sessionData->haveReceivedMessages().push_back({
.reply = std::move(reply),
.outerMsgId = info.outerMsgId,
.requestId = requestId,
});
} else { } else {
DEBUG_LOG(("Message Error: " DEBUG_LOG(("Message Error: "
"such message was not sent recently %1").arg(badMsgId)); "such message was not sent recently %1").arg(badMsgId));
} }
return badTime return info.badTime
? HandleResult::Ignored ? HandleResult::Ignored
: HandleResult::Success; : HandleResult::Success;
} }
@ -1612,19 +1618,19 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
const auto resendId = data.vbad_msg_id().v; const auto resendId = data.vbad_msg_id().v;
if (!wasSent(resendId)) { if (!wasSent(resendId)) {
DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(resendId)); DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(resendId));
return (badTime ? HandleResult::Ignored : HandleResult::Success); return (info.badTime ? HandleResult::Ignored : HandleResult::Success);
} }
_sessionSalt = data.vnew_server_salt().v; _sessionSalt = data.vnew_server_salt().v;
correctUnixtimeWithBadLocal(serverTime); correctUnixtimeWithBadLocal(info.serverTime);
if (setState(ConnectedState, ConnectingState)) { if (setState(ConnectedState, ConnectingState)) {
resendAll(); resendAll();
} }
badTime = false; info.badTime = false;
DEBUG_LOG(("Message Info: unixtime updated, now %1, server_salt updated, now %2, resending...").arg(serverTime).arg(serverSalt)); DEBUG_LOG(("Message Info: unixtime updated, now %1, server_salt updated, now %2, resending...").arg(info.serverTime).arg(info.serverSalt));
resend(resendId); resend(resendId);
} return HandleResult::Success; } return HandleResult::Success;
@ -1642,17 +1648,19 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
const auto i = _stateAndResendRequests.find(reqMsgId); const auto i = _stateAndResendRequests.find(reqMsgId);
if (i == _stateAndResendRequests.end()) { if (i == _stateAndResendRequests.end()) {
DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId)); DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId));
return (badTime ? HandleResult::Ignored : HandleResult::Success); return info.badTime
? HandleResult::Ignored
: HandleResult::Success;
} }
if (badTime) { if (info.badTime) {
if (serverSalt) { if (info.serverSalt) {
_sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup _sessionSalt = info.serverSalt; // requestsFixTimeSalt with no lookup
} }
correctUnixtimeWithBadLocal(serverTime); correctUnixtimeWithBadLocal(info.serverTime);
DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime)); DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(info.serverTime));
badTime = false; info.badTime = false;
} }
const auto originalRequest = i->second; const auto originalRequest = i->second;
Assert(originalRequest->size() > 8); Assert(originalRequest->size() > 8);
@ -1680,7 +1688,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
} return HandleResult::Success; } return HandleResult::Success;
case mtpc_msgs_all_info: { case mtpc_msgs_all_info: {
if (badTime) { if (info.badTime) {
DEBUG_LOG(("Message Info: skipping with bad time...")); DEBUG_LOG(("Message Info: skipping with bad time..."));
return HandleResult::Ignored; return HandleResult::Ignored;
} }
@ -1707,9 +1715,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
DEBUG_LOG(("Message Info: msg detailed info, sent msgId %1, answerId %2, status %3, bytes %4").arg(data.vmsg_id().v).arg(data.vanswer_msg_id().v).arg(data.vstatus().v).arg(data.vbytes().v)); DEBUG_LOG(("Message Info: msg detailed info, sent msgId %1, answerId %2, status %3, bytes %4").arg(data.vmsg_id().v).arg(data.vanswer_msg_id().v).arg(data.vstatus().v).arg(data.vbytes().v));
QVector<MTPlong> ids(1, data.vmsg_id()); QVector<MTPlong> ids(1, data.vmsg_id());
if (badTime) { if (info.badTime) {
if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { if (requestsFixTimeSalt(ids, info)) {
badTime = false; info.badTime = false;
} else { } else {
DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vmsg_id().v)); DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vmsg_id().v));
return HandleResult::Ignored; return HandleResult::Ignored;
@ -1727,7 +1735,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
} return HandleResult::Success; } return HandleResult::Success;
case mtpc_msg_new_detailed_info: { case mtpc_msg_new_detailed_info: {
if (badTime) { if (info.badTime) {
DEBUG_LOG(("Message Info: skipping msg_new_detailed_info with bad time...")); DEBUG_LOG(("Message Info: skipping msg_new_detailed_info with bad time..."));
return HandleResult::Ignored; return HandleResult::Ignored;
} }
@ -1763,9 +1771,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
DEBUG_LOG(("RPC Info: response received for %1, queueing...").arg(requestMsgId)); DEBUG_LOG(("RPC Info: response received for %1, queueing...").arg(requestMsgId));
QVector<MTPlong> ids(1, reqMsgId); QVector<MTPlong> ids(1, reqMsgId);
if (badTime) { if (info.badTime) {
if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { if (requestsFixTimeSalt(ids, info)) {
badTime = false; info.badTime = false;
} else { } else {
DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(requestMsgId)); DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(requestMsgId));
return HandleResult::Ignored; return HandleResult::Ignored;
@ -1804,7 +1812,11 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
if (requestId && requestId != mtpRequestId(0xFFFFFFFF)) { if (requestId && requestId != mtpRequestId(0xFFFFFFFF)) {
// Save rpc_result for processing in the main thread. // Save rpc_result for processing in the main thread.
QWriteLocker locker(_sessionData->haveReceivedMutex()); QWriteLocker locker(_sessionData->haveReceivedMutex());
_sessionData->haveReceivedResponses().emplace(requestId, response); _sessionData->haveReceivedMessages().push_back({
.reply = std::move(response),
.outerMsgId = info.outerMsgId,
.requestId = requestId,
});
} else { } else {
DEBUG_LOG(("RPC Info: requestId not found for msgId %1").arg(requestMsgId)); DEBUG_LOG(("RPC Info: requestId not found for msgId %1").arg(requestMsgId));
} }
@ -1818,9 +1830,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
} }
const auto &data(msg.c_new_session_created()); const auto &data(msg.c_new_session_created());
if (badTime) { if (info.badTime) {
if (requestsFixTimeSalt(QVector<MTPlong>(1, data.vfirst_msg_id()), serverTime, serverSalt)) { if (requestsFixTimeSalt(QVector<MTPlong>(1, data.vfirst_msg_id()), info)) {
badTime = false; info.badTime = false;
} else { } else {
DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vfirst_msg_id().v)); DEBUG_LOG(("Message Info: error, such message was not sent recently %1").arg(data.vfirst_msg_id().v));
return HandleResult::Ignored; return HandleResult::Ignored;
@ -1853,7 +1865,10 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
// Notify main process about new session - need to get difference. // Notify main process about new session - need to get difference.
QWriteLocker locker(_sessionData->haveReceivedMutex()); QWriteLocker locker(_sessionData->haveReceivedMutex());
_sessionData->haveReceivedUpdates().push_back(mtpBuffer(update)); _sessionData->haveReceivedMessages().push_back({
.reply = update,
.outerMsgId = info.outerMsgId,
});
} return HandleResult::Success; } return HandleResult::Success;
case mtpc_pong: { case mtpc_pong: {
@ -1875,9 +1890,9 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
} }
QVector<MTPlong> ids(1, data.vmsg_id()); QVector<MTPlong> ids(1, data.vmsg_id());
if (badTime) { if (info.badTime) {
if (requestsFixTimeSalt(ids, serverTime, serverSalt)) { if (requestsFixTimeSalt(ids, info)) {
badTime = false; info.badTime = false;
} else { } else {
return HandleResult::Ignored; return HandleResult::Ignored;
} }
@ -1887,7 +1902,7 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
} }
if (badTime) { if (info.badTime) {
DEBUG_LOG(("Message Error: bad time in updates cons, must create new session")); DEBUG_LOG(("Message Error: bad time in updates cons, must create new session"));
return HandleResult::ResetSession; return HandleResult::ResetSession;
} }
@ -1900,7 +1915,10 @@ SessionPrivate::HandleResult SessionPrivate::handleOneReceived(
// Notify main process about the new updates. // Notify main process about the new updates.
QWriteLocker locker(_sessionData->haveReceivedMutex()); QWriteLocker locker(_sessionData->haveReceivedMutex());
_sessionData->haveReceivedUpdates().push_back(mtpBuffer(update)); _sessionData->haveReceivedMessages().push_back({
.reply = update,
.outerMsgId = info.outerMsgId,
});
} else { } else {
LOG(("Message Error: unexpected updates in dcType: %1" LOG(("Message Error: unexpected updates in dcType: %1"
).arg(static_cast<int>(_currentDcType))); ).arg(static_cast<int>(_currentDcType)));
@ -1991,14 +2009,14 @@ mtpBuffer SessionPrivate::ungzip(const mtpPrime *from, const mtpPrime *end) cons
return result; return result;
} }
bool SessionPrivate::requestsFixTimeSalt(const QVector<MTPlong> &ids, int32 serverTime, uint64 serverSalt) { bool SessionPrivate::requestsFixTimeSalt(const QVector<MTPlong> &ids, const OuterInfo &info) {
for (const auto &id : ids) { for (const auto &id : ids) {
if (wasSent(id.v)) { if (wasSent(id.v)) {
// Found such msg_id in recent acked or in recent sent requests. // Found such msg_id in recent acked or in recent sent requests.
if (serverSalt) { if (info.serverSalt) {
_sessionSalt = serverSalt; _sessionSalt = info.serverSalt;
} }
correctUnixtimeWithBadLocal(serverTime); correctUnixtimeWithBadLocal(info.serverTime);
return true; return true;
} }
} }
@ -2063,7 +2081,7 @@ void SessionPrivate::requestsAcked(const QVector<MTPlong> &ids, bool byResponse)
if (const auto i = haveSent.find(msgId); i != end(haveSent)) { if (const auto i = haveSent.find(msgId); i != end(haveSent)) {
const auto requestId = i->second->requestId; const auto requestId = i->second->requestId;
if (!byResponse && _instance->hasCallbacks(requestId)) { if (!byResponse && _instance->hasCallback(requestId)) {
DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId)); DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId));
continue; continue;
} }
@ -2076,7 +2094,7 @@ void SessionPrivate::requestsAcked(const QVector<MTPlong> &ids, bool byResponse)
if (const auto i = _resendingIds.find(msgId); i != end(_resendingIds)) { if (const auto i = _resendingIds.find(msgId); i != end(_resendingIds)) {
const auto requestId = i->second; const auto requestId = i->second;
if (!byResponse && _instance->hasCallbacks(requestId)) { if (!byResponse && _instance->hasCallback(requestId)) {
DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId)); DEBUG_LOG(("Message Info: ignoring ACK for msgId %1 because request %2 requires a response").arg(msgId).arg(requestId));
continue; continue;
} }

View file

@ -120,7 +120,17 @@ private:
bool needAnyResponse); bool needAnyResponse);
mtpRequestId wasSent(mtpMsgId msgId) const; mtpRequestId wasSent(mtpMsgId msgId) const;
[[nodiscard]] HandleResult handleOneReceived(const mtpPrime *from, const mtpPrime *end, uint64 msgId, int32 serverTime, uint64 serverSalt, bool badTime); struct OuterInfo {
mtpMsgId outerMsgId = 0;
uint64 serverSalt = 0;
int32 serverTime = 0;
bool badTime = false;
};
[[nodiscard]] HandleResult handleOneReceived(
const mtpPrime *from,
const mtpPrime *end,
uint64 msgId,
OuterInfo info);
[[nodiscard]] HandleResult handleBindResponse( [[nodiscard]] HandleResult handleBindResponse(
mtpMsgId requestMsgId, mtpMsgId requestMsgId,
const mtpBuffer &response); const mtpBuffer &response);
@ -137,7 +147,7 @@ private:
const bytes::vector &protocolSecret); const bytes::vector &protocolSecret);
// if badTime received - search for ids in sessionData->haveSent and sessionData->wereAcked and sync time/salt, return true if found // if badTime received - search for ids in sessionData->haveSent and sessionData->wereAcked and sync time/salt, return true if found
bool requestsFixTimeSalt(const QVector<MTPlong> &ids, int32 serverTime, uint64 serverSalt); bool requestsFixTimeSalt(const QVector<MTPlong> &ids, const OuterInfo &info);
// if we had a confirmed fast request use its unixtime as a correct one. // if we had a confirmed fast request use its unixtime as a correct one.
void correctUnixtimeByFastRequest( void correctUnixtimeByFastRequest(