diff --git a/Cargo.toml b/Cargo.toml index 80af6c7f675..88dde863974 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ bytes = "1.5.0" fxhash = { version = "0.2.1", optional = true } chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true } flate2 = { version = "1.0.28", optional = true } +zstd-safe = { version = "7.2.1", optional = true } reqwest = { version = "0.12.2", default-features = false, features = ["multipart", "stream", "json"], optional = true } tokio-tungstenite = { version = "0.24.0", features = ["url"], optional = true } percent-encoding = { version = "2.3.0", optional = true } @@ -70,6 +71,8 @@ default_no_backend = [ "cache", "chrono", "framework", + "transport_compression_zlib", + "transport_compression_zstd", ] # Enables builder structs to configure Discord HTTP requests. Without this feature, you have to @@ -93,6 +96,10 @@ http = ["dashmap", "mime_guess", "percent-encoding"] # TODO: remove dependeny on utils feature model = ["builder", "http", "utils"] voice_model = ["serenity-voice-model"] +# Enables zlib-stream transport compression of incoming gateway events. +transport_compression_zlib = ["flate2", "gateway"] +# Enables zstd-stream transport compression of incoming gateway events. +transport_compression_zstd = ["zstd-safe", "gateway"] # Enables support for Discord API functionality that's not stable yet, as well as serenity APIs that # are allowed to change even in semver non-breaking updates. unstable = [] diff --git a/src/gateway/client/mod.rs b/src/gateway/client/mod.rs index fc3722c3d67..109e3aa14ac 100644 --- a/src/gateway/client/mod.rs +++ b/src/gateway/client/mod.rs @@ -42,6 +42,7 @@ use tracing::{debug, warn}; pub use self::context::Context; pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler}; +use super::TransportCompression; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "cache")] @@ -82,6 +83,7 @@ pub struct ClientBuilder { raw_event_handler: Option>, presence: PresenceData, wait_time_between_shard_start: Duration, + compression: TransportCompression, } impl ClientBuilder { @@ -116,6 +118,7 @@ impl ClientBuilder { raw_event_handler: None, presence: PresenceData::default(), wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START, + compression: TransportCompression::None, } } @@ -176,6 +179,12 @@ impl ClientBuilder { self } + /// Sets the compression method to be used when receiving data from the gateway. + pub fn compression(mut self, compression: TransportCompression) -> Self { + self.compression = compression; + self + } + /// Sets the voice gateway handler to be used. It will receive voice events sent over the /// gateway and then consider - based on its settings - whether to dispatch a command. #[cfg(feature = "voice")] @@ -342,6 +351,7 @@ impl IntoFuture for ClientBuilder { presence: Some(presence), max_concurrency, wait_time_between_shard_start: self.wait_time_between_shard_start, + compression: self.compression, }); let client = Client { diff --git a/src/gateway/error.rs b/src/gateway/error.rs index d23387de6ed..a3e0d02e5da 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -7,7 +7,7 @@ use tokio_tungstenite::tungstenite::protocol::CloseFrame; /// /// Note that - from a user standpoint - there should be no situation in which you manually handle /// these. -#[derive(Clone, Debug)] +#[derive(Debug)] #[non_exhaustive] pub enum Error { /// There was an error building a URL. @@ -50,6 +50,17 @@ pub enum Error { /// If an connection has been established but privileged gateway intents were provided without /// enabling them prior. DisallowedGatewayIntents, + #[cfg(feature = "transport_compression_zlib")] + /// A decompression error from the `flate2` crate. + DecompressZlib(flate2::DecompressError), + #[cfg(feature = "transport_compression_zstd")] + /// A decompression error from zstd. + DecompressZstd(usize), + /// When zstd decompression fails due to corrupted data. + #[cfg(feature = "transport_compression_zstd")] + DecompressZstdCorrupted, + /// When decompressed gateway data is not valid UTF-8. + DecompressUtf8(std::string::FromUtf8Error), } impl fmt::Display for Error { @@ -70,6 +81,15 @@ impl fmt::Display for Error { Self::DisallowedGatewayIntents => { f.write_str("Disallowed gateway intents were provided") }, + #[cfg(feature = "transport_compression_zlib")] + Self::DecompressZlib(inner) => fmt::Display::fmt(&inner, f), + #[cfg(feature = "transport_compression_zstd")] + Self::DecompressZstd(code) => write!(f, "Zstd decompression error: {code}"), + #[cfg(feature = "transport_compression_zstd")] + Self::DecompressZstdCorrupted => { + f.write_str("Zstd decompression error: corrupted data") + }, + Self::DecompressUtf8(inner) => fmt::Display::fmt(&inner, f), } } } diff --git a/src/gateway/sharding/mod.rs b/src/gateway/sharding/mod.rs index 9ce61ce27a5..de23d3e96e1 100644 --- a/src/gateway/sharding/mod.rs +++ b/src/gateway/sharding/mod.rs @@ -44,7 +44,9 @@ use std::fmt; use std::sync::Arc; use std::time::{Duration as StdDuration, Instant}; -use aformat::{aformat, CapStr}; +#[cfg(feature = "transport_compression_zlib")] +use aformat::aformat_into; +use aformat::{aformat, ArrayString, CapStr}; use tokio_tungstenite::tungstenite::error::Error as TungsteniteError; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use tracing::{debug, error, info, trace, warn}; @@ -113,6 +115,7 @@ pub struct Shard { token: SecretString, ws_url: Arc, resume_ws_url: Option, + compression: TransportCompression, pub intents: GatewayIntents, } @@ -129,7 +132,7 @@ impl Shard { /// use std::num::NonZeroU16; /// use std::sync::Arc; /// - /// use serenity::gateway::Shard; + /// use serenity::gateway::{Shard, TransportCompression}; /// use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// use serenity::model::id::ShardId; /// use serenity::secret_string::SecretString; @@ -147,7 +150,15 @@ impl Shard { /// /// // retrieve the gateway response, which contains the URL to connect to /// let gateway = Arc::from(http.get_gateway().await?.url); - /// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?; + /// let shard = Shard::new( + /// gateway, + /// token, + /// shard_info, + /// GatewayIntents::all(), + /// None, + /// TransportCompression::None, + /// ) + /// .await?; /// /// // at this point, you can create a `loop`, and receive events and match /// // their variants @@ -165,8 +176,9 @@ impl Shard { shard_info: ShardInfo, intents: GatewayIntents, presence: Option, + compression: TransportCompression, ) -> Result { - let client = connect(&ws_url).await?; + let client = connect(&ws_url, compression).await?; let presence = presence.unwrap_or_default(); let last_heartbeat_sent = None; @@ -193,6 +205,7 @@ impl Shard { shard_info, ws_url, resume_ws_url: None, + compression, intents, }) } @@ -748,7 +761,7 @@ impl Shard { // Hello is received. self.stage = ConnectionStage::Connecting; self.started = Instant::now(); - let client = connect(ws_url).await?; + let client = connect(ws_url, self.compression).await?; self.stage = ConnectionStage::Handshake; Ok(client) @@ -807,14 +820,19 @@ impl Shard { } } -async fn connect(base_url: &str) -> Result { - let url = Url::parse(&aformat!("{}?v={}", CapStr::<64>(base_url), constants::GATEWAY_VERSION)) - .map_err(|why| { - warn!("Error building gateway URL with base `{base_url}`: {why:?}"); - Error::Gateway(GatewayError::BuildingUrl) - })?; - - WsClient::connect(url).await +async fn connect(base_url: &str, compression: TransportCompression) -> Result { + let url = Url::parse(&aformat!( + "{}?v={}{}", + CapStr::<64>(base_url), + constants::GATEWAY_VERSION, + compression.query_param() + )) + .map_err(|why| { + warn!("Error building gateway URL with base `{base_url}`: {why:?}"); + Error::Gateway(GatewayError::BuildingUrl) + })?; + + WsClient::connect(url, compression).await } #[derive(Debug)] @@ -954,3 +972,41 @@ impl PartialEq for CollectorCallback { Arc::ptr_eq(&self.0, &other.0) } } + +/// The transport compression method to use. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum TransportCompression { + /// No transport compression. Payload compression will be used instead. + None, + + #[cfg(feature = "transport_compression_zlib")] + /// Use zlib-stream transport compression. + Zlib, + + #[cfg(feature = "transport_compression_zstd")] + /// Use zstd-stream transport compression. + Zstd, +} + +impl TransportCompression { + fn query_param(self) -> ArrayString<21> { + #[cfg_attr( + not(any( + feature = "transport_compression_zlib", + feature = "transport_compression_zstd" + )), + expect(unused_mut) + )] + let mut res = ArrayString::new(); + match self { + Self::None => {}, + #[cfg(feature = "transport_compression_zlib")] + Self::Zlib => aformat_into!(res, "&compress=zlib-stream"), + #[cfg(feature = "transport_compression_zstd")] + Self::Zstd => aformat_into!(res, "&compress=zstd-stream"), + } + + res + } +} diff --git a/src/gateway/sharding/shard_manager.rs b/src/gateway/sharding/shard_manager.rs index 471896d2846..a80aa835879 100644 --- a/src/gateway/sharding/shard_manager.rs +++ b/src/gateway/sharding/shard_manager.rs @@ -11,7 +11,14 @@ use tokio::sync::Mutex; use tokio::time::timeout; use tracing::{info, warn}; -use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo}; +use super::{ + ShardId, + ShardQueue, + ShardQueuer, + ShardQueuerMessage, + ShardRunnerInfo, + TransportCompression, +}; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "framework")] @@ -53,7 +60,12 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5); /// use std::sync::{Arc, OnceLock}; /// /// use serenity::gateway::client::EventHandler; -/// use serenity::gateway::{ShardManager, ShardManagerOptions, DEFAULT_WAIT_BETWEEN_SHARD_START}; +/// use serenity::gateway::{ +/// ShardManager, +/// ShardManagerOptions, +/// TransportCompression, +/// DEFAULT_WAIT_BETWEEN_SHARD_START, +/// }; /// use serenity::http::Http; /// use serenity::model::gateway::GatewayIntents; /// use serenity::prelude::*; @@ -88,6 +100,7 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5); /// presence: None, /// max_concurrency, /// wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START, +/// compression: TransportCompression::None, /// }); /// # Ok(()) /// # } @@ -144,6 +157,7 @@ impl ShardManager { #[cfg(feature = "voice")] voice_manager: opt.voice_manager, ws_url: opt.ws_url, + compression: opt.compression, shard_total: opt.shard_total, #[cfg(feature = "cache")] cache: opt.cache, @@ -379,4 +393,5 @@ pub struct ShardManagerOptions { pub max_concurrency: NonZeroU16, /// Number of seconds to wait between starting each shard/set of shards start pub wait_time_between_shard_start: Duration, + pub compression: TransportCompression, } diff --git a/src/gateway/sharding/shard_queuer.rs b/src/gateway/sharding/shard_queuer.rs index 94cad7dc607..27935992c00 100644 --- a/src/gateway/sharding/shard_queuer.rs +++ b/src/gateway/sharding/shard_queuer.rs @@ -17,6 +17,7 @@ use super::{ ShardRunner, ShardRunnerInfo, ShardRunnerOptions, + TransportCompression, }; #[cfg(feature = "cache")] use crate::cache::Cache; @@ -64,6 +65,8 @@ pub struct ShardQueuer { pub voice_manager: Option>, /// A copy of the URL to use to connect to the gateway. pub ws_url: Arc, + /// The compression method to use for the WebSocket connection. + pub compression: TransportCompression, /// The total amount of shards to start. pub shard_total: NonZeroU16, /// Number of seconds to wait between each start @@ -216,6 +219,7 @@ impl ShardQueuer { shard_info, self.intents, self.presence.clone(), + self.compression, ) .await?; diff --git a/src/gateway/sharding/shard_runner.rs b/src/gateway/sharding/shard_runner.rs index c8000c559f6..00b6f91763b 100644 --- a/src/gateway/sharding/shard_runner.rs +++ b/src/gateway/sharding/shard_runner.rs @@ -451,7 +451,16 @@ impl ShardRunner { )) => { error!("Shard handler received fatal err: {why:?}"); - self.manager.return_with_value(Err(why.clone())).await; + let why_clone = match why { + GatewayError::InvalidAuthentication => GatewayError::InvalidAuthentication, + GatewayError::InvalidGatewayIntents => GatewayError::InvalidGatewayIntents, + GatewayError::DisallowedGatewayIntents => { + GatewayError::DisallowedGatewayIntents + }, + _ => unreachable!(), + }; + + self.manager.return_with_value(Err(why_clone)).await; return Err(Error::Gateway(why)); }, Err(Error::Json(_)) => return Ok((None, None, true)), diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 9fd28e1a17f..d88a1b569f7 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,8 +1,11 @@ +use std::borrow::Cow; use std::env::consts; use std::io::Read; use std::time::SystemTime; use flate2::read::ZlibDecoder; +#[cfg(feature = "transport_compression_zlib")] +use flate2::Decompress as ZlibInflater; use futures::{SinkExt, StreamExt}; use small_fixed_array::FixedString; use tokio::net::TcpStream; @@ -12,8 +15,10 @@ use tokio_tungstenite::tungstenite::{Error as WsError, Message}; use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream}; use tracing::{debug, trace, warn}; use url::Url; +#[cfg(feature = "transport_compression_zstd")] +use zstd_safe::{DStream as ZstdInflater, InBuffer, OutBuffer}; -use super::{ActivityData, ChunkGuildFilter, GatewayError, PresenceData}; +use super::{ActivityData, ChunkGuildFilter, GatewayError, PresenceData, TransportCompression}; use crate::constants::{self, Opcode}; use crate::model::event::GatewayEvent; use crate::model::gateway::{GatewayIntents, ShardInfo}; @@ -75,13 +80,154 @@ struct WebSocketMessage<'a> { d: WebSocketMessageData<'a>, } -pub struct WsClient(WebSocketStream>); +enum Compression { + Payload { + decompressed: Vec, + }, + + #[cfg(feature = "transport_compression_zlib")] + Zlib { + inflater: ZlibInflater, + compressed: Vec, + decompressed: Box<[u8]>, + }, + + #[cfg(feature = "transport_compression_zstd")] + Zstd { + inflater: ZstdInflater<'static>, + decompressed: Box<[u8]>, + }, +} + +impl Compression { + #[cfg(any(feature = "transport_compression_zlib", feature = "transport_compression_zstd"))] + const DECOMPRESSED_CAPACITY: usize = 64 * 1024; + + fn inflate(&mut self, slice: &[u8]) -> Result> { + match self { + Compression::Payload { + decompressed, + } => { + const DECOMPRESSION_MULTIPLIER: usize = 3; + + decompressed.clear(); + decompressed.reserve(slice.len() * DECOMPRESSION_MULTIPLIER); + + ZlibDecoder::new(slice).read_to_end(decompressed).map_err(|why| { + warn!("Err decompressing bytes: {why:?}"); + debug!("Failing bytes: {slice:?}"); + + why + })?; + + Ok(Some(decompressed.as_slice())) + }, + + #[cfg(feature = "transport_compression_zlib")] + Compression::Zlib { + inflater, + compressed, + decompressed, + } => { + const ZLIB_SUFFIX: [u8; 4] = [0x00, 0x00, 0xFF, 0xFF]; + + compressed.extend_from_slice(slice); + let length = compressed.len(); + + if length < 4 || compressed[length - 4..] != ZLIB_SUFFIX { + return Ok(None); + } + + let pre_out = inflater.total_out(); + inflater + .decompress(compressed, decompressed, flate2::FlushDecompress::Sync) + .map_err(GatewayError::DecompressZlib)?; + compressed.clear(); + let produced = (inflater.total_out() - pre_out) as usize; + + Ok(Some(&decompressed[..produced])) + }, + + #[cfg(feature = "transport_compression_zstd")] + Compression::Zstd { + inflater, + decompressed, + } => { + let mut in_buffer = InBuffer::around(slice); + let mut out_buffer = OutBuffer::around(decompressed.as_mut()); + + let length = slice.len(); + let mut processed = 0; + + loop { + match inflater.decompress_stream(&mut out_buffer, &mut in_buffer) { + Ok(0) => break, + Ok(_hint) => {}, + Err(code) => { + return Err(Error::Gateway(GatewayError::DecompressZstd(code))) + }, + } + + let in_pos = in_buffer.pos(); + + let progressed = in_pos > processed; + let read_all_input = in_pos == length; + + if !progressed { + if read_all_input { + break; + } + + return Err(Error::Gateway(GatewayError::DecompressZstdCorrupted)); + } + + processed = in_pos; + } + + let produced = out_buffer.pos(); + Ok(Some(&decompressed[..produced])) + }, + } + } +} + +impl From for Compression { + fn from(value: TransportCompression) -> Self { + match value { + TransportCompression::None => Compression::Payload { + decompressed: Vec::new(), + }, + + #[cfg(feature = "transport_compression_zlib")] + TransportCompression::Zlib => Compression::Zlib { + inflater: ZlibInflater::new(true), + compressed: Vec::new(), + decompressed: vec![0; Self::DECOMPRESSED_CAPACITY].into_boxed_slice(), + }, + + #[cfg(feature = "transport_compression_zstd")] + TransportCompression::Zstd => { + let mut inflater = ZstdInflater::create(); + inflater.init().expect("Failed to initialize Zstd decompressor"); + + Compression::Zstd { + inflater, + decompressed: vec![0; Self::DECOMPRESSED_CAPACITY].into_boxed_slice(), + } + }, + } + } +} + +pub struct WsClient { + stream: WebSocketStream>, + compression: Compression, +} const TIMEOUT: Duration = Duration::from_millis(500); -const DECOMPRESSION_MULTIPLIER: usize = 3; impl WsClient { - pub(crate) async fn connect(url: Url) -> Result { + pub(crate) async fn connect(url: Url, compression: TransportCompression) -> Result { let config = WebSocketConfig { max_message_size: None, max_frame_size: None, @@ -89,30 +235,27 @@ impl WsClient { }; let (stream, _) = connect_async_with_config(url, Some(config), false).await?; - Ok(Self(stream)) + Ok(Self { + stream, + compression: compression.into(), + }) } pub(crate) async fn recv_json(&mut self) -> Result> { - let message = match timeout(TIMEOUT, self.0.next()).await { + let message = match timeout(TIMEOUT, self.stream.next()).await { Ok(Some(Ok(msg))) => msg, Ok(Some(Err(e))) => return Err(e.into()), Ok(None) | Err(_) => return Ok(None), }; - let json_str = match message { - Message::Text(payload) => payload, + let json_bytes = match message { + Message::Text(payload) => Cow::Owned(payload.into_bytes()), Message::Binary(bytes) => { - let mut decompressed = - String::with_capacity(bytes.len() * DECOMPRESSION_MULTIPLIER); - - ZlibDecoder::new(&bytes[..]).read_to_string(&mut decompressed).map_err(|why| { - warn!("Err decompressing bytes: {why:?}"); - debug!("Failing bytes: {bytes:?}"); - - why - })?; + let Some(decompressed) = self.compression.inflate(&bytes)? else { + return Ok(None); + }; - decompressed + Cow::Borrowed(decompressed) }, Message::Close(Some(frame)) => { return Err(Error::Gateway(GatewayError::Closed(Some(frame)))); @@ -120,19 +263,21 @@ impl WsClient { _ => return Ok(None), }; - match serde_json::from_str(&json_str) { + // TODO: Use `String::from_utf8_lossy_owned` when stable. + let json_str = || String::from_utf8_lossy(&json_bytes); + match serde_json::from_slice(&json_bytes) { Ok(mut event) => { if let GatewayEvent::Dispatch { original_str, .. } = &mut event { - *original_str = FixedString::from_string_trunc(json_str); + *original_str = FixedString::from_string_trunc(json_str().into_owned()); } Ok(Some(event)) }, Err(err) => { - debug!("Failing text: {json_str}"); + debug!("Failing text: {}", json_str()); Err(Error::Json(err)) }, } @@ -141,24 +286,24 @@ impl WsClient { pub(crate) async fn send_json(&mut self, value: &impl serde::Serialize) -> Result<()> { let message = serde_json::to_string(value).map(Message::Text)?; - self.0.send(message).await?; + self.stream.send(message).await?; Ok(()) } /// Delegate to `StreamExt::next` pub(crate) async fn next(&mut self) -> Option> { - self.0.next().await + self.stream.next().await } /// Delegate to `SinkExt::send` pub(crate) async fn send(&mut self, message: Message) -> Result<()> { - self.0.send(message).await?; + self.stream.send(message).await?; Ok(()) } /// Delegate to `WebSocketStream::close` pub(crate) async fn close(&mut self, msg: Option>) -> Result<()> { - self.0.close(msg).await?; + self.stream.close(msg).await?; Ok(()) } @@ -230,7 +375,7 @@ impl WsClient { token, shard, intents, - compress: true, + compress: matches!(self.compression, Compression::Payload { .. }), large_threshold: constants::LARGE_THRESHOLD, properties: IdentifyProperties { browser: "serenity",