diff --git a/Telegram/SourceFiles/storage/file_download.cpp b/Telegram/SourceFiles/storage/file_download.cpp index 3ba5ab641..14a1b298d 100644 --- a/Telegram/SourceFiles/storage/file_download.cpp +++ b/Telegram/SourceFiles/storage/file_download.cpp @@ -84,6 +84,7 @@ constexpr auto kDownloadPhotoPartSize = 64 * 1024; // 64kb for photo constexpr auto kDownloadDocumentPartSize = 128 * 1024; // 128kb for document constexpr auto kMaxFileQueries = 16; // max 16 file parts downloaded at the same time constexpr auto kMaxWebFileQueries = 8; // max 8 http[s] files downloaded at the same time +constexpr auto kDownloadCdnPartSize = 128 * 1024; // 128kb for cdn requests } // namespace @@ -452,10 +453,19 @@ bool mtpFileLoader::loadPart() { } int mtpFileLoader::partSize() const { - if (_locationType == UnknownFileLocation) { - return kDownloadPhotoPartSize; - } - return kDownloadDocumentPartSize; + return kDownloadCdnPartSize; + + // Different part sizes are not supported for now :( + // Because we start downloading with some part size + // and then we get a cdn-redirect where we support only + // fixed part size download for hash checking. + // + //if (_cdnDcId) { + // return kDownloadCdnPartSize; + //} else if (_locationType == UnknownFileLocation) { + // return kDownloadPhotoPartSize; + //} + //return kDownloadDocumentPartSize; } mtpFileLoader::RequestData mtpFileLoader::prepareRequest(int offset) const { @@ -492,6 +502,21 @@ void mtpFileLoader::makeRequest(int offset) { placeSentRequest(send(), requestData); } +void mtpFileLoader::requestMoreCdnFileHashes() { + if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) { + return; + } + + auto offset = _cdnUncheckedParts.cbegin()->first; + auto requestData = RequestData(); + requestData.dcId = _dcId; + requestData.dcIndex = 0; + requestData.offset = offset; + auto shiftedDcId = MTP::downloadDcId(requestData.dcId, requestData.dcIndex); + auto requestId = _cdnHashesRequestId = MTP::send(MTPupload_GetCdnFileHashes(MTP_bytes(_cdnToken), MTP_int(offset)), rpcDone(&mtpFileLoader::getCdnFileHashesDone), rpcFail(&mtpFileLoader::cdnPartFailed), shiftedDcId); + placeSentRequest(requestId, requestData); +} + void mtpFileLoader::normalPartLoaded(const MTPupload_File &result, mtpRequestId requestId) { Expects(result.type() == mtpc_upload_fileCdnRedirect || result.type() == mtpc_upload_file); @@ -550,14 +575,88 @@ void mtpFileLoader::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId auto decryptInPlace = result.c_upload_cdnFile().vbytes.v; MTP::aesCtrEncrypt(decryptInPlace.data(), decryptInPlace.size(), key.data(), &state); auto bytes = gsl::as_bytes(gsl::make_span(decryptInPlace)); - return partLoaded(offset, bytes); + + switch (checkCdnFileHash(offset, bytes)) { + case CheckCdnHashResult::NoHash: { + _cdnUncheckedParts.emplace(offset, decryptInPlace); + requestMoreCdnFileHashes(); + } return; + + case CheckCdnHashResult::Invalid: { + LOG(("API Error: Wrong cdnFileHash for offset %1.").arg(offset)); + cancel(true); + } return; + + case CheckCdnHashResult::Good: { + partLoaded(offset, bytes); + } return; + } + Unexpected("Result of checkCdnFileHash()"); } -void mtpFileLoader::reuploadDone(const MTPBool &result, mtpRequestId requestId) { +mtpFileLoader::CheckCdnHashResult mtpFileLoader::checkCdnFileHash(int offset, base::const_byte_span bytes) { + auto cdnFileHashIt = _cdnFileHashes.find(offset); + if (cdnFileHashIt == _cdnFileHashes.cend()) { + return CheckCdnHashResult::NoHash; + } + auto realHash = hashSha256(bytes.data(), bytes.size()); + if (!base::compare_bytes(gsl::as_bytes(gsl::make_span(realHash)), gsl::as_bytes(gsl::make_span(cdnFileHashIt->second.hash)))) { + return CheckCdnHashResult::Invalid; + } + return CheckCdnHashResult::Good; +} + +void mtpFileLoader::reuploadDone(const MTPVector &result, mtpRequestId requestId) { auto offset = finishSentRequestGetOffset(requestId); + addCdnHashes(result.v); makeRequest(offset); } +void mtpFileLoader::getCdnFileHashesDone(const MTPVector &result, mtpRequestId requestId) { + Expects(_cdnHashesRequestId == requestId); + _cdnHashesRequestId = 0; + + auto offset = finishSentRequestGetOffset(requestId); + addCdnHashes(result.v); + auto someMoreChecked = false; + for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) { + auto bytes = gsl::as_bytes(gsl::make_span(i->second)); + + switch (checkCdnFileHash(offset, bytes)) { + case CheckCdnHashResult::NoHash: { + ++i; + } break; + + case CheckCdnHashResult::Invalid: { + LOG(("API Error: Wrong cdnFileHash for offset %1.").arg(offset)); + cancel(true); + return; + } break; + + case CheckCdnHashResult::Good: { + someMoreChecked = true; + auto goodOffset = i->first; + auto goodBytes = std::move(i->second); + i = _cdnUncheckedParts.erase(i); + auto finished = (i == _cdnUncheckedParts.cend()); + partLoaded(goodOffset, gsl::as_bytes(gsl::make_span(goodBytes))); + if (finished) { + // Perhaps we were destroyed already?.. + return; + } + } break; + + default: Unexpected("Result of checkCdnFileHash()"); + } + } + if (!someMoreChecked) { + LOG(("API Error: Could not find cdnFileHash for offset %1 after getCdnFileHashes request.").arg(offset)); + cancel(true); + } else { + requestMoreCdnFileHashes(); + } +} + void mtpFileLoader::placeSentRequest(mtpRequestId requestId, const RequestData &requestData) { _downloader->requestedAmountIncrement(requestData.dcId, requestData.dcIndex, partSize()); ++_queue->queriesCount; @@ -669,9 +768,12 @@ bool mtpFileLoader::partFailed(const RPCError &error) { bool mtpFileLoader::cdnPartFailed(const RPCError &error, mtpRequestId requestId) { if (MTP::isDefaultHandledError(error)) return false; + if (requestId == _cdnHashesRequestId) { + _cdnHashesRequestId = 0; + } if (error.type() == qstr("FILE_TOKEN_INVALID") || error.type() == qstr("REQUEST_TOKEN_INVALID")) { auto offset = finishSentRequestGetOffset(requestId); - changeCDNParams(offset, 0, QByteArray(), QByteArray(), QByteArray()); + changeCDNParams(offset, 0, QByteArray(), QByteArray(), QByteArray(), QVector()); return true; } return partFailed(error); @@ -686,10 +788,18 @@ void mtpFileLoader::cancelRequests() { } void mtpFileLoader::switchToCDN(int offset, const MTPDupload_fileCdnRedirect &redirect) { - changeCDNParams(offset, redirect.vdc_id.v, redirect.vfile_token.v, redirect.vencryption_key.v, redirect.vencryption_iv.v); + changeCDNParams(offset, redirect.vdc_id.v, redirect.vfile_token.v, redirect.vencryption_key.v, redirect.vencryption_iv.v, redirect.vcdn_file_hashes.v); } -void mtpFileLoader::changeCDNParams(int offset, MTP::DcId dcId, const QByteArray &token, const QByteArray &encryptionKey, const QByteArray &encryptionIV) { +void mtpFileLoader::addCdnHashes(const QVector &hashes) { + for_const (auto &hash, hashes) { + t_assert(hash.type() == mtpc_cdnFileHash); + auto &data = hash.c_cdnFileHash(); + _cdnFileHashes.emplace(data.voffset.v, CdnFileHash { data.vlimit.v, data.vhash.v }); + } +} + +void mtpFileLoader::changeCDNParams(int offset, MTP::DcId dcId, const QByteArray &token, const QByteArray &encryptionKey, const QByteArray &encryptionIV, const QVector &hashes) { if (dcId != 0 && (encryptionKey.size() != MTP::CTRState::KeySize || encryptionIV.size() != MTP::CTRState::IvecSize)) { LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params").arg(encryptionKey.size()).arg(encryptionIV.size())); cancel(true); @@ -704,6 +814,7 @@ void mtpFileLoader::changeCDNParams(int offset, MTP::DcId dcId, const QByteArray _cdnToken = token; _cdnEncryptionKey = encryptionKey; _cdnEncryptionIV = encryptionIV; + addCdnHashes(hashes); if (resendAllRequests && !_sentRequests.empty()) { auto resendOffsets = std::vector(); diff --git a/Telegram/SourceFiles/storage/file_download.h b/Telegram/SourceFiles/storage/file_download.h index 62631d65a..d53299312 100644 --- a/Telegram/SourceFiles/storage/file_download.h +++ b/Telegram/SourceFiles/storage/file_download.h @@ -211,6 +211,12 @@ private: int dcIndex = 0; int offset = 0; }; + struct CdnFileHash { + CdnFileHash(int limit, QByteArray hash) : limit(limit), hash(hash) { + } + int limit = 0; + QByteArray hash; + }; bool tryLoadLocal() override; void cancelRequests() override; @@ -223,7 +229,9 @@ private: void normalPartLoaded(const MTPupload_File &result, mtpRequestId requestId); void webPartLoaded(const MTPupload_WebFile &result, mtpRequestId requestId); void cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId); - void reuploadDone(const MTPBool &result, mtpRequestId requestId); + void reuploadDone(const MTPVector &result, mtpRequestId requestId); + void requestMoreCdnFileHashes(); + void getCdnFileHashesDone(const MTPVector &result, mtpRequestId requestId); void partLoaded(int offset, base::const_byte_span bytes); bool partFailed(const RPCError &error); @@ -232,7 +240,15 @@ private: void placeSentRequest(mtpRequestId requestId, const RequestData &requestData); int finishSentRequestGetOffset(mtpRequestId requestId); void switchToCDN(int offset, const MTPDupload_fileCdnRedirect &redirect); - void changeCDNParams(int offset, MTP::DcId dcId, const QByteArray &token, const QByteArray &encryptionKey, const QByteArray &encryptionIV); + void addCdnHashes(const QVector &hashes); + void changeCDNParams(int offset, MTP::DcId dcId, const QByteArray &token, const QByteArray &encryptionKey, const QByteArray &encryptionIV, const QVector &hashes); + + enum class CheckCdnHashResult { + NoHash, + Invalid, + Good, + }; + CheckCdnHashResult checkCdnFileHash(int offset, base::const_byte_span bytes); std::map _sentRequests; @@ -253,6 +269,9 @@ private: QByteArray _cdnToken; QByteArray _cdnEncryptionKey; QByteArray _cdnEncryptionIV; + std::map _cdnFileHashes; + std::map _cdnUncheckedParts; + mtpRequestId _cdnHashesRequestId = 0; };