fix: Add safety checks to ZSTD decompression and improve HTTP client

- Add maximum packet size limit (16 MB) to prevent memory exhaustion
- Add empty input validation for ZSTD decompression
- Improve error handling with detailed logging
- Increase HTTP timeout from 30 to 60 seconds
- Enable HTTP redirect following and keep-alive connections

Signed-off-by: Zephyron <zephyron@citron-emu.org>
This commit is contained in:
Zephyron
2025-12-03 12:03:06 +10:00
parent 54cba480e6
commit 240b8f7aef
2 changed files with 32 additions and 2 deletions

View File

@@ -1,9 +1,11 @@
// SPDX-FileCopyrightText: Copyright 2019 yuzu Emulator Project // SPDX-FileCopyrightText: Copyright 2019 yuzu Emulator Project
// SPDX-FileCopyrightText: Copyright 2025 citron Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include <algorithm> #include <algorithm>
#include <zstd.h> #include <zstd.h>
#include "common/logging/log.h"
#include "common/zstd_compression.h" #include "common/zstd_compression.h"
namespace Common::Compression { namespace Common::Compression {
@@ -32,17 +34,43 @@ std::vector<u8> CompressDataZSTDDefault(const u8* source, std::size_t source_siz
} }
std::vector<u8> DecompressDataZSTD(std::span<const u8> compressed) { std::vector<u8> DecompressDataZSTD(std::span<const u8> compressed) {
if (compressed.empty()) {
return {};
}
const std::size_t decompressed_size = const std::size_t decompressed_size =
ZSTD_getFrameContentSize(compressed.data(), compressed.size()); ZSTD_getFrameContentSize(compressed.data(), compressed.size());
// Define a reasonable maximum size for a decompressed network packet.
// 16 MB is a very generous limit for a single game packet.
constexpr u64 MAX_REASONABLE_PACKET_SIZE = 16 * 1024 * 1024;
// ZSTD_getFrameContentSize can return special values if the size isn't in the header
// or if there's an error. We must check for these AND our own sanity limit.
if (decompressed_size == ZSTD_CONTENTSIZE_ERROR ||
decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ||
decompressed_size > MAX_REASONABLE_PACKET_SIZE) {
LOG_ERROR(Common, "Received network packet with invalid or oversized decompressed_size: {}", decompressed_size);
return {}; // Return an empty vector to signal a graceful failure.
}
std::vector<u8> decompressed(decompressed_size); std::vector<u8> decompressed(decompressed_size);
const std::size_t uncompressed_result_size = ZSTD_decompress( const std::size_t uncompressed_result_size = ZSTD_decompress(
decompressed.data(), decompressed.size(), compressed.data(), compressed.size()); decompressed.data(), decompressed.size(), compressed.data(), compressed.size());
if (decompressed_size != uncompressed_result_size || ZSTD_isError(uncompressed_result_size)) { if (ZSTD_isError(uncompressed_result_size)) { // check the result of decompress
// Decompression failed // Decompression failed
LOG_ERROR(Common, "ZSTD_decompress failed with error: {}", ZSTD_getErrorName(uncompressed_result_size));
return {}; return {};
} }
if (decompressed_size != uncompressed_result_size) {
LOG_ERROR(Common, "ZSTD decompressed size mismatch. Expected {}, got {}", decompressed_size, uncompressed_result_size);
return {};
}
return decompressed; return decompressed;
} }

View File

@@ -27,7 +27,7 @@ namespace WebService {
constexpr std::array<const char, 1> API_VERSION{'1'}; constexpr std::array<const char, 1> API_VERSION{'1'};
constexpr std::size_t TIMEOUT_SECONDS = 30; constexpr std::size_t TIMEOUT_SECONDS = 60;
struct Client::Impl { struct Client::Impl {
Impl(std::string host_, std::string username_, std::string token_) Impl(std::string host_, std::string username_, std::string token_)
@@ -80,6 +80,8 @@ struct Client::Impl {
// Create a new client for each request. This is the safest approach in a // Create a new client for each request. This is the safest approach in a
// multi-threaded environment as it avoids sharing a single client instance. // multi-threaded environment as it avoids sharing a single client instance.
httplib::Client cli(host.c_str()); httplib::Client cli(host.c_str());
cli.set_follow_location(true);
cli.set_keep_alive(true);
cli.set_connection_timeout(TIMEOUT_SECONDS); cli.set_connection_timeout(TIMEOUT_SECONDS);
cli.set_read_timeout(TIMEOUT_SECONDS); cli.set_read_timeout(TIMEOUT_SECONDS);
cli.set_write_timeout(TIMEOUT_SECONDS); cli.set_write_timeout(TIMEOUT_SECONDS);