From 18e2d9f336d0a17ae7f4f3a8f44c387068038baa Mon Sep 17 00:00:00 2001 From: Michael Krasnitski Date: Mon, 19 Feb 2024 11:20:57 -0500 Subject: [PATCH] Add `CreateAttachmentStream` type --- src/builder/create_attachment.rs | 160 ++++++++++++++++-- src/builder/create_forum_post.rs | 2 +- src/builder/create_interaction_response.rs | 2 +- .../create_interaction_response_followup.rs | 2 +- src/builder/create_message.rs | 2 +- src/builder/edit_interaction_response.rs | 5 +- src/builder/edit_message.rs | 5 +- src/builder/edit_webhook_message.rs | 5 +- src/builder/execute_webhook.rs | 2 +- 9 files changed, 162 insertions(+), 23 deletions(-) diff --git a/src/builder/create_attachment.rs b/src/builder/create_attachment.rs index 0a016b25b86..b7c97ab9f18 100644 --- a/src/builder/create_attachment.rs +++ b/src/builder/create_attachment.rs @@ -4,6 +4,7 @@ use std::path::Path; use serde::ser::{Serialize, SerializeSeq, Serializer}; use tokio::fs::File; use tokio::io::AsyncReadExt; +use url::Url; #[allow(unused)] // Error is used in docs use crate::error::{Error, Result}; @@ -111,14 +112,136 @@ impl<'a> CreateAttachment<'a> { } } +/// Streaming alternative to [`CreateAttachment`] that does not read data into memory until +/// consumed and sent as part of a request, at which point necessary disk reads or network requests +/// will be made. +/// +/// **Note**: Cloning this type does not clone its associated data - meaning, cloned attachment +/// streams will result in extra disk reads or network requests when being sent off. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct CreateAttachmentStream<'a> { + pub filename: Cow<'static, str>, + pub description: Option>, + pub kind: AttachmentStreamKind<'a>, +} + +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum AttachmentStreamKind<'a> { + Path(&'a Path), + File(&'a File), + Url(&'a Http, Url), +} + +impl<'a> CreateAttachmentStream<'a> { + /// Builds a [`CreateAttachmentStream`] by storing the path to a file. + /// + /// # Errors + /// + /// [`Error::Io`] if the path does not point to a file. + pub fn path(path: &'a Path) -> Result { + let filename = path + .file_name() + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "attachment path must not be a directory", + ) + })? + .to_string_lossy() + .into_owned(); + Ok(CreateAttachmentStream { + filename: filename.into(), + description: None, + kind: AttachmentStreamKind::Path(path), + }) + } + + /// Builds a [`CreateAttachmentStream`] by storing a file handler. + pub fn file(file: &'a File, filename: impl Into>) -> Self { + CreateAttachmentStream { + filename: filename.into(), + description: None, + kind: AttachmentStreamKind::File(file), + } + } + + /// Builds an [`CreateAttachmentStream`] by storing a URL. + /// + /// # Errors + /// + /// Returns [`Error::Http`] if the url is not valid. + pub fn url( + http: &'a Http, + url: impl reqwest::IntoUrl, + filename: impl Into>, + ) -> Result { + Ok(CreateAttachmentStream { + filename: filename.into(), + description: None, + kind: AttachmentStreamKind::Url(http, url.into_url()?), + }) + } + + /// Sets a description for the file (max 1024 characters). + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Reads the attachment data either from disk or from the network. + async fn data(&self) -> Result> { + match &self.kind { + AttachmentStreamKind::Path(path) => { + let mut file = File::open(path).await?; + let mut buf = Vec::new(); + file.read_to_end(&mut buf).await?; + Ok(buf) + }, + AttachmentStreamKind::File(file) => { + let mut buf = Vec::new(); + file.try_clone().await?.read_to_end(&mut buf).await?; + Ok(buf) + }, + AttachmentStreamKind::Url(http, url) => { + let response = http.client.get(url.clone()).send().await?; + Ok(response.bytes().await?.to_vec()) + }, + } + } +} + #[derive(Clone, Debug, Serialize)] struct ExistingAttachment { id: AttachmentId, } +#[derive(Clone, Debug)] +enum NewAttachment<'a> { + Bytes(CreateAttachment<'a>), + Stream(CreateAttachmentStream<'a>), +} + +impl NewAttachment<'_> { + fn filename(&self) -> &Cow<'static, str> { + match self { + NewAttachment::Bytes(attachment) => &attachment.filename, + NewAttachment::Stream(attachment) => &attachment.filename, + } + } + + fn description(&self) -> &Option> { + match self { + NewAttachment::Bytes(attachment) => &attachment.description, + NewAttachment::Stream(attachment) => &attachment.description, + } + } +} + #[derive(Clone, Debug)] enum NewOrExisting<'a> { - New(CreateAttachment<'a>), + New(NewAttachment<'a>), Existing(ExistingAttachment), } @@ -243,33 +366,40 @@ impl<'a> EditAttachments<'a> { /// Adds a new attachment to the attachment list. #[allow(clippy::should_implement_trait)] // Clippy thinks add == std::ops::Add::add pub fn add(mut self, attachment: CreateAttachment<'a>) -> Self { - self.new_and_existing_attachments.push(NewOrExisting::New(attachment)); + self.new_and_existing_attachments + .push(NewOrExisting::New(NewAttachment::Bytes(attachment))); + self + } + + /// Adds a new attachment stream to the attachment list. + pub fn add_stream(mut self, attachment: CreateAttachmentStream<'a>) -> Self { + self.new_and_existing_attachments + .push(NewOrExisting::New(NewAttachment::Stream(attachment))); self } /// Clones all new attachments into a new Vec, keeping only data and filename, because those /// are needed for the multipart form data. The data is taken out of `self` in the process, so /// this method can only be called once. - pub(crate) fn take_files(&mut self) -> Vec> { + pub(crate) async fn take_files(&mut self) -> Result>> { let mut files = Vec::new(); for attachment in &mut self.new_and_existing_attachments { - if let NewOrExisting::New(attachment) = attachment { - let cloned_attachment = CreateAttachment::bytes( - std::mem::take(&mut attachment.data), - attachment.filename.clone(), - ); - - files.push(cloned_attachment); + if let NewOrExisting::New(new_attachment) = attachment { + let data = match new_attachment { + NewAttachment::Bytes(attachment) => std::mem::take(&mut attachment.data), + NewAttachment::Stream(attachment) => attachment.data().await?.into(), + }; + files.push(CreateAttachment::bytes(data, new_attachment.filename().clone())) } } - files + Ok(files) } } impl<'a> Serialize for EditAttachments<'a> { fn serialize(&self, serializer: S) -> Result { #[derive(Serialize)] - struct NewAttachment<'a> { + struct AttachmentMetadata<'a> { id: u64, filename: &'a Cow<'static, str>, description: &'a Option>, @@ -283,10 +413,10 @@ impl<'a> Serialize for EditAttachments<'a> { for attachment in &self.new_and_existing_attachments { match attachment { NewOrExisting::New(new_attachment) => { - let attachment = NewAttachment { + let attachment = AttachmentMetadata { id, - filename: &new_attachment.filename, - description: &new_attachment.description, + filename: &new_attachment.filename(), + description: &new_attachment.description(), }; id += 1; seq.serialize_element(&attachment)?; diff --git a/src/builder/create_forum_post.rs b/src/builder/create_forum_post.rs index 5cf358b083b..e61ea6454be 100644 --- a/src/builder/create_forum_post.rs +++ b/src/builder/create_forum_post.rs @@ -99,7 +99,7 @@ impl<'a> CreateForumPost<'a> { /// Returns [`Error::Http`] if the current user lacks permission, or if invalid data is given. #[cfg(feature = "http")] pub async fn execute(mut self, http: &Http, channel_id: ChannelId) -> Result { - let files = self.message.attachments.take_files(); + let files = self.message.attachments.take_files().await?; http.create_forum_post(channel_id, &self, files, self.audit_log_reason).await } } diff --git a/src/builder/create_interaction_response.rs b/src/builder/create_interaction_response.rs index ac72f67a748..bd84358ddb0 100644 --- a/src/builder/create_interaction_response.rs +++ b/src/builder/create_interaction_response.rs @@ -121,7 +121,7 @@ impl CreateInteractionResponse<'_> { let files = match &mut self { CreateInteractionResponse::Message(msg) | CreateInteractionResponse::Defer(msg) - | CreateInteractionResponse::UpdateMessage(msg) => msg.attachments.take_files(), + | CreateInteractionResponse::UpdateMessage(msg) => msg.attachments.take_files().await?, _ => Vec::new(), }; diff --git a/src/builder/create_interaction_response_followup.rs b/src/builder/create_interaction_response_followup.rs index 5bef0096c16..35396441fa0 100644 --- a/src/builder/create_interaction_response_followup.rs +++ b/src/builder/create_interaction_response_followup.rs @@ -167,7 +167,7 @@ impl<'a> CreateInteractionResponseFollowup<'a> { ) -> Result { self.check_length()?; - let files = self.attachments.take_files(); + let files = self.attachments.take_files().await?; if self.allowed_mentions.is_none() { self.allowed_mentions.clone_from(&http.default_allowed_mentions); diff --git a/src/builder/create_message.rs b/src/builder/create_message.rs index 6047f019b11..13ae3b056e5 100644 --- a/src/builder/create_message.rs +++ b/src/builder/create_message.rs @@ -295,7 +295,7 @@ impl<'a> CreateMessage<'a> { ) -> Result { self.check_length()?; - let files = self.attachments.take_files(); + let files = self.attachments.take_files().await?; if self.allowed_mentions.is_none() { self.allowed_mentions.clone_from(&http.default_allowed_mentions); } diff --git a/src/builder/edit_interaction_response.rs b/src/builder/edit_interaction_response.rs index ea3172bb1c3..2ee6694d339 100644 --- a/src/builder/edit_interaction_response.rs +++ b/src/builder/edit_interaction_response.rs @@ -114,7 +114,10 @@ impl<'a> EditInteractionResponse<'a> { pub async fn execute(mut self, http: &Http, interaction_token: &str) -> Result { self.0.check_length()?; - let files = self.0.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files); + let files = match self.0.attachments.as_mut() { + Some(attachments) => attachments.take_files().await?, + None => Vec::new(), + }; http.edit_original_interaction_response(interaction_token, &self, files).await } diff --git a/src/builder/edit_message.rs b/src/builder/edit_message.rs index ae0e735f3bc..8a9257b4e1f 100644 --- a/src/builder/edit_message.rs +++ b/src/builder/edit_message.rs @@ -252,7 +252,10 @@ impl<'a> EditMessage<'a> { } } - let files = self.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files); + let files = match self.attachments.as_mut() { + Some(attachments) => attachments.take_files().await?, + None => Vec::new(), + }; let http = cache_http.http(); if self.allowed_mentions.is_none() { diff --git a/src/builder/edit_webhook_message.rs b/src/builder/edit_webhook_message.rs index bbed86fe17d..6662f6bb11f 100644 --- a/src/builder/edit_webhook_message.rs +++ b/src/builder/edit_webhook_message.rs @@ -164,7 +164,10 @@ impl<'a> EditWebhookMessage<'a> { ) -> Result { self.check_length()?; - let files = self.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files); + let files = match self.attachments.as_mut() { + Some(attachments) => attachments.take_files().await?, + None => Vec::new(), + }; if self.allowed_mentions.is_none() { self.allowed_mentions.clone_from(&http.default_allowed_mentions); diff --git a/src/builder/execute_webhook.rs b/src/builder/execute_webhook.rs index 02259b8d95d..5f260e89eef 100644 --- a/src/builder/execute_webhook.rs +++ b/src/builder/execute_webhook.rs @@ -342,7 +342,7 @@ impl<'a> ExecuteWebhook<'a> { ) -> Result> { self.check_length()?; - let files = self.attachments.take_files(); + let files = self.attachments.take_files().await?; if self.allowed_mentions.is_none() { self.allowed_mentions.clone_from(&http.default_allowed_mentions);