Skip to content

Commit

Permalink
Add CreateAttachmentStream type
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrasnitski committed Aug 16, 2024
1 parent b167b4d commit d6a5271
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 23 deletions.
160 changes: 145 additions & 15 deletions src/builder/create_attachment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::path::Path;
use serde::ser::{Serialize, SerializeSeq, Serializer};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use url::Url;

#[allow(unused)] // Error is used in docs
use crate::error::{Error, Result};
Expand Down Expand Up @@ -111,14 +112,136 @@ impl<'a> CreateAttachment<'a> {
}
}

/// Streaming alternative to [`CreateAttachment`] that does not read data into memory until
/// consumed and sent as part of a request, at which point necessary disk reads or network requests
/// will be made.
///
/// **Note**: Cloning this type does not clone its associated data - meaning, cloned attachment
/// streams will result in extra disk reads or network requests when being sent off.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CreateAttachmentStream<'a> {
pub filename: Cow<'static, str>,
pub description: Option<Cow<'a, str>>,
pub kind: AttachmentStreamKind<'a>,
}

#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AttachmentStreamKind<'a> {
Path(&'a Path),
File(&'a File),
Url(&'a Http, Url),
}

impl<'a> CreateAttachmentStream<'a> {
/// Builds a [`CreateAttachmentStream`] by storing the path to a file.
///
/// # Errors
///
/// [`Error::Io`] if the path does not point to a file.
pub fn path(path: &'a Path) -> Result<Self> {
let filename = path
.file_name()
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::Other,
"attachment path must not be a directory",
)
})?
.to_string_lossy()
.into_owned();
Ok(CreateAttachmentStream {
filename: filename.into(),
description: None,
kind: AttachmentStreamKind::Path(path),
})
}

/// Builds a [`CreateAttachmentStream`] by storing a file handler.
pub fn file(file: &'a File, filename: impl Into<Cow<'static, str>>) -> Self {
CreateAttachmentStream {
filename: filename.into(),
description: None,
kind: AttachmentStreamKind::File(file),
}
}

/// Builds an [`CreateAttachmentStream`] by storing a URL.
///
/// # Errors
///
/// Returns [`Error::Http`] if the url is not valid.
pub fn url(
http: &'a Http,
url: impl reqwest::IntoUrl,
filename: impl Into<Cow<'static, str>>,
) -> Result<Self> {
Ok(CreateAttachmentStream {
filename: filename.into(),
description: None,
kind: AttachmentStreamKind::Url(http, url.into_url()?),
})
}

/// Sets a description for the file (max 1024 characters).
pub fn description(mut self, description: impl Into<Cow<'a, str>>) -> Self {
self.description = Some(description.into());
self
}

/// Reads the attachment data either from disk or from the network.
async fn data(&self) -> Result<Vec<u8>> {
match &self.kind {
AttachmentStreamKind::Path(path) => {
let mut file = File::open(path).await?;
let mut buf = Vec::new();
file.read_to_end(&mut buf).await?;
Ok(buf)
},
AttachmentStreamKind::File(file) => {
let mut buf = Vec::new();
file.try_clone().await?.read_to_end(&mut buf).await?;
Ok(buf)
},
AttachmentStreamKind::Url(http, url) => {
let response = http.client.get(url.clone()).send().await?;
Ok(response.bytes().await?.to_vec())
},
}
}
}

#[derive(Clone, Debug, Serialize)]
struct ExistingAttachment {
id: AttachmentId,
}

#[derive(Clone, Debug)]
enum NewAttachment<'a> {
Bytes(CreateAttachment<'a>),
Stream(CreateAttachmentStream<'a>),
}

impl NewAttachment<'_> {
fn filename(&self) -> &Cow<'static, str> {
match self {
NewAttachment::Bytes(attachment) => &attachment.filename,
NewAttachment::Stream(attachment) => &attachment.filename,
}
}

fn description(&self) -> &Option<Cow<'_, str>> {
match self {
NewAttachment::Bytes(attachment) => &attachment.description,
NewAttachment::Stream(attachment) => &attachment.description,
}
}
}

#[derive(Clone, Debug)]
enum NewOrExisting<'a> {
New(CreateAttachment<'a>),
New(NewAttachment<'a>),
Existing(ExistingAttachment),
}

Expand Down Expand Up @@ -243,33 +366,40 @@ impl<'a> EditAttachments<'a> {
/// Adds a new attachment to the attachment list.
#[allow(clippy::should_implement_trait)] // Clippy thinks add == std::ops::Add::add
pub fn add(mut self, attachment: CreateAttachment<'a>) -> Self {
self.new_and_existing_attachments.push(NewOrExisting::New(attachment));
self.new_and_existing_attachments
.push(NewOrExisting::New(NewAttachment::Bytes(attachment)));
self
}

/// Adds a new attachment stream to the attachment list.
pub fn add_stream(mut self, attachment: CreateAttachmentStream<'a>) -> Self {
self.new_and_existing_attachments
.push(NewOrExisting::New(NewAttachment::Stream(attachment)));
self
}

/// Clones all new attachments into a new Vec, keeping only data and filename, because those
/// are needed for the multipart form data. The data is taken out of `self` in the process, so
/// this method can only be called once.
pub(crate) fn take_files(&mut self) -> Vec<CreateAttachment<'a>> {
pub(crate) async fn take_files(&mut self) -> Result<Vec<CreateAttachment<'a>>> {
let mut files = Vec::new();
for attachment in &mut self.new_and_existing_attachments {
if let NewOrExisting::New(attachment) = attachment {
let cloned_attachment = CreateAttachment::bytes(
std::mem::take(&mut attachment.data),
attachment.filename.clone(),
);

files.push(cloned_attachment);
if let NewOrExisting::New(new_attachment) = attachment {
let data = match new_attachment {
NewAttachment::Bytes(attachment) => std::mem::take(&mut attachment.data),
NewAttachment::Stream(attachment) => attachment.data().await?.into(),
};
files.push(CreateAttachment::bytes(data, new_attachment.filename().clone()))
}
}
files
Ok(files)
}
}

impl<'a> Serialize for EditAttachments<'a> {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
#[derive(Serialize)]
struct NewAttachment<'a> {
struct AttachmentMetadata<'a> {
id: u64,
filename: &'a Cow<'static, str>,
description: &'a Option<Cow<'a, str>>,
Expand All @@ -283,10 +413,10 @@ impl<'a> Serialize for EditAttachments<'a> {
for attachment in &self.new_and_existing_attachments {
match attachment {
NewOrExisting::New(new_attachment) => {
let attachment = NewAttachment {
let attachment = AttachmentMetadata {
id,
filename: &new_attachment.filename,
description: &new_attachment.description,
filename: &new_attachment.filename(),
description: &new_attachment.description(),
};
id += 1;
seq.serialize_element(&attachment)?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/create_forum_post.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl<'a> CreateForumPost<'a> {
/// Returns [`Error::Http`] if the current user lacks permission, or if invalid data is given.
#[cfg(feature = "http")]
pub async fn execute(mut self, http: &Http, channel_id: ChannelId) -> Result<GuildChannel> {
let files = self.message.attachments.take_files();
let files = self.message.attachments.take_files().await?;
http.create_forum_post(channel_id, &self, files, self.audit_log_reason).await
}
}
2 changes: 1 addition & 1 deletion src/builder/create_interaction_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl CreateInteractionResponse<'_> {
let files = match &mut self {
CreateInteractionResponse::Message(msg)
| CreateInteractionResponse::Defer(msg)
| CreateInteractionResponse::UpdateMessage(msg) => msg.attachments.take_files(),
| CreateInteractionResponse::UpdateMessage(msg) => msg.attachments.take_files().await?,
_ => Vec::new(),
};

Expand Down
2 changes: 1 addition & 1 deletion src/builder/create_interaction_response_followup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl<'a> CreateInteractionResponseFollowup<'a> {
) -> Result<Message> {
self.check_length()?;

let files = self.attachments.take_files();
let files = self.attachments.take_files().await?;

if self.allowed_mentions.is_none() {
self.allowed_mentions.clone_from(&http.default_allowed_mentions);
Expand Down
2 changes: 1 addition & 1 deletion src/builder/create_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ impl<'a> CreateMessage<'a> {
) -> Result<Message> {
self.check_length()?;

let files = self.attachments.take_files();
let files = self.attachments.take_files().await?;
if self.allowed_mentions.is_none() {
self.allowed_mentions.clone_from(&http.default_allowed_mentions);
}
Expand Down
5 changes: 4 additions & 1 deletion src/builder/edit_interaction_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ impl<'a> EditInteractionResponse<'a> {
pub async fn execute(mut self, http: &Http, interaction_token: &str) -> Result<Message> {
self.0.check_length()?;

let files = self.0.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files);
let files = match self.0.attachments.as_mut() {
Some(attachments) => attachments.take_files().await?,
None => Vec::new(),
};

http.edit_original_interaction_response(interaction_token, &self, files).await
}
Expand Down
5 changes: 4 additions & 1 deletion src/builder/edit_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,10 @@ impl<'a> EditMessage<'a> {
}
}

let files = self.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files);
let files = match self.attachments.as_mut() {
Some(attachments) => attachments.take_files().await?,
None => Vec::new(),
};

let http = cache_http.http();
if self.allowed_mentions.is_none() {
Expand Down
5 changes: 4 additions & 1 deletion src/builder/edit_webhook_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ impl<'a> EditWebhookMessage<'a> {
) -> Result<Message> {
self.check_length()?;

let files = self.attachments.as_mut().map_or(Vec::new(), EditAttachments::take_files);
let files = match self.attachments.as_mut() {
Some(attachments) => attachments.take_files().await?,
None => Vec::new(),
};

if self.allowed_mentions.is_none() {
self.allowed_mentions.clone_from(&http.default_allowed_mentions);
Expand Down
2 changes: 1 addition & 1 deletion src/builder/execute_webhook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl<'a> ExecuteWebhook<'a> {
) -> Result<Option<Message>> {
self.check_length()?;

let files = self.attachments.take_files();
let files = self.attachments.take_files().await?;

if self.allowed_mentions.is_none() {
self.allowed_mentions.clone_from(&http.default_allowed_mentions);
Expand Down

0 comments on commit d6a5271

Please sign in to comment.