Skip to content

Commit

Permalink
Transport compression support (serenity-rs#3036)
Browse files Browse the repository at this point in the history
Co-authored-by: GnomedDev <[email protected]>
  • Loading branch information
fgardt and GnomedDev authored Nov 18, 2024
1 parent 2e12663 commit b20151f
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 43 deletions.
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ bytes = "1.5.0"
fxhash = { version = "0.2.1", optional = true }
chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true }
flate2 = { version = "1.0.28", optional = true }
zstd-safe = { version = "7.2.1", optional = true }
reqwest = { version = "0.12.2", default-features = false, features = ["multipart", "stream", "json"], optional = true }
tokio-tungstenite = { version = "0.24.0", features = ["url"], optional = true }
percent-encoding = { version = "2.3.0", optional = true }
Expand Down Expand Up @@ -70,6 +71,8 @@ default_no_backend = [
"cache",
"chrono",
"framework",
"transport_compression_zlib",
"transport_compression_zstd",
]

# Enables builder structs to configure Discord HTTP requests. Without this feature, you have to
Expand All @@ -93,6 +96,10 @@ http = ["dashmap", "mime_guess", "percent-encoding"]
# TODO: remove dependeny on utils feature
model = ["builder", "http", "utils"]
voice_model = ["serenity-voice-model"]
# Enables zlib-stream transport compression of incoming gateway events.
transport_compression_zlib = ["flate2", "gateway"]
# Enables zstd-stream transport compression of incoming gateway events.
transport_compression_zstd = ["zstd-safe", "gateway"]
# Enables support for Discord API functionality that's not stable yet, as well as serenity APIs that
# are allowed to change even in semver non-breaking updates.
unstable = []
Expand Down
10 changes: 10 additions & 0 deletions src/gateway/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use tracing::{debug, warn};

pub use self::context::Context;
pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler};
use super::TransportCompression;
#[cfg(feature = "cache")]
use crate::cache::Cache;
#[cfg(feature = "cache")]
Expand Down Expand Up @@ -82,6 +83,7 @@ pub struct ClientBuilder {
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
presence: PresenceData,
wait_time_between_shard_start: Duration,
compression: TransportCompression,
}

impl ClientBuilder {
Expand Down Expand Up @@ -116,6 +118,7 @@ impl ClientBuilder {
raw_event_handler: None,
presence: PresenceData::default(),
wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START,
compression: TransportCompression::None,
}
}

Expand Down Expand Up @@ -176,6 +179,12 @@ impl ClientBuilder {
self
}

/// Sets the compression method to be used when receiving data from the gateway.
pub fn compression(mut self, compression: TransportCompression) -> Self {
self.compression = compression;
self
}

/// Sets the voice gateway handler to be used. It will receive voice events sent over the
/// gateway and then consider - based on its settings - whether to dispatch a command.
#[cfg(feature = "voice")]
Expand Down Expand Up @@ -342,6 +351,7 @@ impl IntoFuture for ClientBuilder {
presence: Some(presence),
max_concurrency,
wait_time_between_shard_start: self.wait_time_between_shard_start,
compression: self.compression,
});

let client = Client {
Expand Down
22 changes: 21 additions & 1 deletion src/gateway/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tokio_tungstenite::tungstenite::protocol::CloseFrame;
///
/// Note that - from a user standpoint - there should be no situation in which you manually handle
/// these.
#[derive(Clone, Debug)]
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
/// There was an error building a URL.
Expand Down Expand Up @@ -50,6 +50,17 @@ pub enum Error {
/// If an connection has been established but privileged gateway intents were provided without
/// enabling them prior.
DisallowedGatewayIntents,
#[cfg(feature = "transport_compression_zlib")]
/// A decompression error from the `flate2` crate.
DecompressZlib(flate2::DecompressError),
#[cfg(feature = "transport_compression_zstd")]
/// A decompression error from zstd.
DecompressZstd(usize),
/// When zstd decompression fails due to corrupted data.
#[cfg(feature = "transport_compression_zstd")]
DecompressZstdCorrupted,
/// When decompressed gateway data is not valid UTF-8.
DecompressUtf8(std::string::FromUtf8Error),
}

impl fmt::Display for Error {
Expand All @@ -70,6 +81,15 @@ impl fmt::Display for Error {
Self::DisallowedGatewayIntents => {
f.write_str("Disallowed gateway intents were provided")
},
#[cfg(feature = "transport_compression_zlib")]
Self::DecompressZlib(inner) => fmt::Display::fmt(&inner, f),
#[cfg(feature = "transport_compression_zstd")]
Self::DecompressZstd(code) => write!(f, "Zstd decompression error: {code}"),
#[cfg(feature = "transport_compression_zstd")]
Self::DecompressZstdCorrupted => {
f.write_str("Zstd decompression error: corrupted data")
},
Self::DecompressUtf8(inner) => fmt::Display::fmt(&inner, f),
}
}
}
Expand Down
82 changes: 69 additions & 13 deletions src/gateway/sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ use std::fmt;
use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};

use aformat::{aformat, CapStr};
#[cfg(feature = "transport_compression_zlib")]
use aformat::aformat_into;
use aformat::{aformat, ArrayString, CapStr};
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
use tracing::{debug, error, info, trace, warn};
Expand Down Expand Up @@ -113,6 +115,7 @@ pub struct Shard {
token: SecretString,
ws_url: Arc<str>,
resume_ws_url: Option<FixedString>,
compression: TransportCompression,
pub intents: GatewayIntents,
}

Expand All @@ -129,7 +132,7 @@ impl Shard {
/// use std::num::NonZeroU16;
/// use std::sync::Arc;
///
/// use serenity::gateway::Shard;
/// use serenity::gateway::{Shard, TransportCompression};
/// use serenity::model::gateway::{GatewayIntents, ShardInfo};
/// use serenity::model::id::ShardId;
/// use serenity::secret_string::SecretString;
Expand All @@ -147,7 +150,15 @@ impl Shard {
///
/// // retrieve the gateway response, which contains the URL to connect to
/// let gateway = Arc::from(http.get_gateway().await?.url);
/// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?;
/// let shard = Shard::new(
/// gateway,
/// token,
/// shard_info,
/// GatewayIntents::all(),
/// None,
/// TransportCompression::None,
/// )
/// .await?;
///
/// // at this point, you can create a `loop`, and receive events and match
/// // their variants
Expand All @@ -165,8 +176,9 @@ impl Shard {
shard_info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
compression: TransportCompression,
) -> Result<Shard> {
let client = connect(&ws_url).await?;
let client = connect(&ws_url, compression).await?;

let presence = presence.unwrap_or_default();
let last_heartbeat_sent = None;
Expand All @@ -193,6 +205,7 @@ impl Shard {
shard_info,
ws_url,
resume_ws_url: None,
compression,
intents,
})
}
Expand Down Expand Up @@ -748,7 +761,7 @@ impl Shard {
// Hello is received.
self.stage = ConnectionStage::Connecting;
self.started = Instant::now();
let client = connect(ws_url).await?;
let client = connect(ws_url, self.compression).await?;
self.stage = ConnectionStage::Handshake;

Ok(client)
Expand Down Expand Up @@ -807,14 +820,19 @@ impl Shard {
}
}

async fn connect(base_url: &str) -> Result<WsClient> {
let url = Url::parse(&aformat!("{}?v={}", CapStr::<64>(base_url), constants::GATEWAY_VERSION))
.map_err(|why| {
warn!("Error building gateway URL with base `{base_url}`: {why:?}");
Error::Gateway(GatewayError::BuildingUrl)
})?;

WsClient::connect(url).await
async fn connect(base_url: &str, compression: TransportCompression) -> Result<WsClient> {
let url = Url::parse(&aformat!(
"{}?v={}{}",
CapStr::<64>(base_url),
constants::GATEWAY_VERSION,
compression.query_param()
))
.map_err(|why| {
warn!("Error building gateway URL with base `{base_url}`: {why:?}");
Error::Gateway(GatewayError::BuildingUrl)
})?;

WsClient::connect(url, compression).await
}

#[derive(Debug)]
Expand Down Expand Up @@ -954,3 +972,41 @@ impl PartialEq for CollectorCallback {
Arc::ptr_eq(&self.0, &other.0)
}
}

/// The transport compression method to use.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TransportCompression {
/// No transport compression. Payload compression will be used instead.
None,

#[cfg(feature = "transport_compression_zlib")]
/// Use zlib-stream transport compression.
Zlib,

#[cfg(feature = "transport_compression_zstd")]
/// Use zstd-stream transport compression.
Zstd,
}

impl TransportCompression {
fn query_param(self) -> ArrayString<21> {
#[cfg_attr(
not(any(
feature = "transport_compression_zlib",
feature = "transport_compression_zstd"
)),
expect(unused_mut)
)]
let mut res = ArrayString::new();
match self {
Self::None => {},
#[cfg(feature = "transport_compression_zlib")]
Self::Zlib => aformat_into!(res, "&compress=zlib-stream"),
#[cfg(feature = "transport_compression_zstd")]
Self::Zstd => aformat_into!(res, "&compress=zstd-stream"),
}

res
}
}
19 changes: 17 additions & 2 deletions src/gateway/sharding/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{info, warn};

use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
use super::{
ShardId,
ShardQueue,
ShardQueuer,
ShardQueuerMessage,
ShardRunnerInfo,
TransportCompression,
};
#[cfg(feature = "cache")]
use crate::cache::Cache;
#[cfg(feature = "framework")]
Expand Down Expand Up @@ -53,7 +60,12 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5);
/// use std::sync::{Arc, OnceLock};
///
/// use serenity::gateway::client::EventHandler;
/// use serenity::gateway::{ShardManager, ShardManagerOptions, DEFAULT_WAIT_BETWEEN_SHARD_START};
/// use serenity::gateway::{
/// ShardManager,
/// ShardManagerOptions,
/// TransportCompression,
/// DEFAULT_WAIT_BETWEEN_SHARD_START,
/// };
/// use serenity::http::Http;
/// use serenity::model::gateway::GatewayIntents;
/// use serenity::prelude::*;
Expand Down Expand Up @@ -88,6 +100,7 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5);
/// presence: None,
/// max_concurrency,
/// wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START,
/// compression: TransportCompression::None,
/// });
/// # Ok(())
/// # }
Expand Down Expand Up @@ -144,6 +157,7 @@ impl ShardManager {
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
compression: opt.compression,
shard_total: opt.shard_total,
#[cfg(feature = "cache")]
cache: opt.cache,
Expand Down Expand Up @@ -379,4 +393,5 @@ pub struct ShardManagerOptions {
pub max_concurrency: NonZeroU16,
/// Number of seconds to wait between starting each shard/set of shards start
pub wait_time_between_shard_start: Duration,
pub compression: TransportCompression,
}
4 changes: 4 additions & 0 deletions src/gateway/sharding/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use super::{
ShardRunner,
ShardRunnerInfo,
ShardRunnerOptions,
TransportCompression,
};
#[cfg(feature = "cache")]
use crate::cache::Cache;
Expand Down Expand Up @@ -64,6 +65,8 @@ pub struct ShardQueuer {
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// A copy of the URL to use to connect to the gateway.
pub ws_url: Arc<str>,
/// The compression method to use for the WebSocket connection.
pub compression: TransportCompression,
/// The total amount of shards to start.
pub shard_total: NonZeroU16,
/// Number of seconds to wait between each start
Expand Down Expand Up @@ -216,6 +219,7 @@ impl ShardQueuer {
shard_info,
self.intents,
self.presence.clone(),
self.compression,
)
.await?;

Expand Down
11 changes: 10 additions & 1 deletion src/gateway/sharding/shard_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,16 @@ impl ShardRunner {
)) => {
error!("Shard handler received fatal err: {why:?}");

self.manager.return_with_value(Err(why.clone())).await;
let why_clone = match why {
GatewayError::InvalidAuthentication => GatewayError::InvalidAuthentication,
GatewayError::InvalidGatewayIntents => GatewayError::InvalidGatewayIntents,
GatewayError::DisallowedGatewayIntents => {
GatewayError::DisallowedGatewayIntents
},
_ => unreachable!(),
};

self.manager.return_with_value(Err(why_clone)).await;
return Err(Error::Gateway(why));
},
Err(Error::Json(_)) => return Ok((None, None, true)),
Expand Down
Loading

0 comments on commit b20151f

Please sign in to comment.