diff --git a/src/common/zstd_compression.cpp b/src/common/zstd_compression.cpp index 19bb37c82..9aad95902 100644 --- a/src/common/zstd_compression.cpp +++ b/src/common/zstd_compression.cpp @@ -45,14 +45,73 @@ std::vector DecompressDataZSTD(std::span compressed) { // 16 MB is a very generous limit for a single game packet. constexpr u64 MAX_REASONABLE_PACKET_SIZE = 16 * 1024 * 1024; - // ZSTD_getFrameContentSize can return special values if the size isn't in the header - // or if there's an error. We must check for these AND our own sanity limit. - if (decompressed_size == ZSTD_CONTENTSIZE_ERROR || - decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN || - decompressed_size > MAX_REASONABLE_PACKET_SIZE) { + // ZSTD_CONTENTSIZE_ERROR indicates a corrupted frame or invalid data - reject it + // ZSTD_CONTENTSIZE_UNKNOWN means the size isn't in the header but decompression can still work + if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { + LOG_ERROR(Common, "Received network packet with corrupted or invalid ZSTD frame"); + return {}; + } - LOG_ERROR(Common, "Received network packet with invalid or oversized decompressed_size: {}", decompressed_size); - return {}; // Return an empty vector to signal a graceful failure. + // Reject packets that claim to be larger than reasonable + if (decompressed_size != ZSTD_CONTENTSIZE_UNKNOWN && decompressed_size > MAX_REASONABLE_PACKET_SIZE) { + LOG_ERROR(Common, "Received network packet with oversized decompressed_size: {}", decompressed_size); + return {}; + } + + // When size is unknown, use streaming decompression with a reasonable initial buffer + if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) { + // Use streaming decompression for unknown size + ZSTD_DCtx* dctx = ZSTD_createDCtx(); + if (!dctx) { + LOG_ERROR(Common, "Failed to create ZSTD decompression context"); + return {}; + } + + std::vector decompressed; + decompressed.resize(64 * 1024); // Start with 64KB buffer + + ZSTD_inBuffer input = {compressed.data(), compressed.size(), 0}; + ZSTD_outBuffer output = {decompressed.data(), decompressed.size(), 0}; + + while (input.pos < input.size) { + const size_t ret = ZSTD_decompressStream(dctx, &output, &input); + if (ZSTD_isError(ret)) { + LOG_ERROR(Common, "ZSTD streaming decompression failed with error: {}", ZSTD_getErrorName(ret)); + ZSTD_freeDCtx(dctx); + return {}; + } + + // If ret == 0, decompression is complete + if (ret == 0) { + break; + } + + // If output buffer is full but we haven't consumed all input, need more space + if (output.pos >= output.size && input.pos < input.size) { + // Double the buffer size, up to maximum + if (decompressed.size() > MAX_REASONABLE_PACKET_SIZE) { + LOG_ERROR(Common, "ZSTD decompressed size exceeds maximum reasonable packet size"); + ZSTD_freeDCtx(dctx); + return {}; + } + const size_t old_size = decompressed.size(); + decompressed.resize(std::min(old_size * 2, static_cast(MAX_REASONABLE_PACKET_SIZE))); + output.dst = decompressed.data(); + output.size = decompressed.size(); + // Keep output.pos as is - it points to where we continue writing + } + } + + // Ensure all data was consumed + if (input.pos < input.size) { + LOG_ERROR(Common, "ZSTD streaming decompression: not all input was consumed"); + ZSTD_freeDCtx(dctx); + return {}; + } + + decompressed.resize(output.pos); + ZSTD_freeDCtx(dctx); + return decompressed; } std::vector decompressed(decompressed_size); diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp index b43de3a2c..4d2f4bb8f 100644 --- a/src/core/internal_network/socket_proxy.cpp +++ b/src/core/internal_network/socket_proxy.cpp @@ -46,6 +46,14 @@ void ProxySocket::HandleProxyPacket(const ProxyPacket& packet) { auto decompressed = packet; decompressed.data = Common::Compression::DecompressDataZSTD(packet.data); + // Check if decompression failed (returns empty vector on error) + if (decompressed.data.empty() && !packet.data.empty()) { + stats.packets_dropped++; + LOG_WARNING(Network, "Dropped packet: ZSTD decompression failed. Stats: sent={}, recv={}, dropped={}", + stats.packets_sent, stats.packets_received, stats.packets_dropped); + return; + } + std::lock_guard guard(packets_mutex); received_packets.push(decompressed); stats.packets_received++; @@ -204,8 +212,18 @@ void ProxySocket::SendPacket(ProxyPacket& packet) { if (auto room_member = room_network.GetRoomMember().lock()) { if (room_member->IsConnected()) { const size_t original_size = packet.data.size(); + const std::vector original_data = packet.data; // Save original for potential fallback packet.data = Common::Compression::CompressDataZSTDDefault(packet.data.data(), packet.data.size()); + + // Check if compression failed (returns empty vector on error) + if (packet.data.empty() && !original_data.empty()) { + stats.packets_dropped++; + LOG_ERROR(Network, "Failed to compress packet: ZSTD compression failed. Dropping packet. Stats: sent={}, dropped={}", + stats.packets_sent, stats.packets_dropped); + return; + } + room_member->SendProxyPacket(packet); stats.packets_sent++; @@ -260,14 +278,7 @@ std::pair ProxySocket::SendTo(u32 flags, std::span message packet.data.clear(); std::copy(message.begin(), message.end(), std::back_inserter(packet.data)); - // Determine if packet should use unreliable delivery for better latency - // Use unreliable delivery for: - // 1. Small, frequent game data packets (< 1200 bytes for typical MTU) - // 2. UDP protocol packets (most game traffic) - // 3. Non-broadcast packets (broadcast should be reliable for coordination) - const bool is_game_data = protocol == Protocol::UDP && message.size() < 1200 && !packet.broadcast; - packet.reliable = !is_game_data; - + // All packets use reliable delivery SendPacket(packet); return {static_cast(message.size()), Errno::SUCCESS};