diff --git a/CHANGELOG.md b/CHANGELOG.md index 7197853..50870fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,11 +3,14 @@ ## [UNRELEASED] ### Added -- (pull/9) Handle errors for `ParamsLoader` and `ContextLoader` implementations +- (pull/25) Use `cargo nextest` for GitHub CI jobs +- (pull/20) Added `FileChannel`, `StreamChannel`, `BufferedStreamChannel` implementing `MessageWriterChannel` +- (pull/20) Added a simplified `PipesDefaultMessageWriter` implementing `MessageWriter` +- (pull/20) Defined `MessageWriter` and the associated `MessageWriterChannel` traits - (pull/14) Derived `PartialEq` for all types generated by `quicktype` - (pull/14) Renamed `ParamsLoader` and `ContextLoader` traits to `LoadParams` and `LoadContext` respectively - (pull/14) Fixed failing unit tests in `context_loader.rs` -- (pull/25) Use `cargo nextest` for GitHub CI jobs +- (pull/9) Handle errors for `ParamsLoader` and `ContextLoader` implementations ## 0.1.6 diff --git a/Cargo.lock b/Cargo.lock index 3227d94..40f66b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,15 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "base64" version = "0.22.1" @@ -41,6 +50,7 @@ version = "0.1.6" dependencies = [ "base64", "flate2", + "rstest", "serde", "serde_json", "tempfile", @@ -73,6 +83,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "itoa" version = "1.0.14" @@ -130,6 +146,77 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + +[[package]] +name = "rstest" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" +dependencies = [ + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.41" @@ -149,6 +236,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.215" diff --git a/Cargo.toml b/Cargo.toml index 64465a5..4489f73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,5 @@ flate2 = "1" thiserror = "2.0.3" [dev-dependencies] +rstest = { version = "0.23.0", default-features = false } tempfile = "3.14.0" diff --git a/example-dagster-pipes-rust-project/rust_processing_jobs/Cargo.lock b/example-dagster-pipes-rust-project/rust_processing_jobs/Cargo.lock index d964449..ee30919 100644 --- a/example-dagster-pipes-rust-project/rust_processing_jobs/Cargo.lock +++ b/example-dagster-pipes-rust-project/rust_processing_jobs/Cargo.lock @@ -31,7 +31,7 @@ dependencies = [ [[package]] name = "dagster_pipes_rust" -version = "0.1.5" +version = "0.1.6" dependencies = [ "base64", "flate2", diff --git a/quicktype.sh b/quicktype.sh index f8cc214..ca8d21f 100755 --- a/quicktype.sh +++ b/quicktype.sh @@ -1,4 +1,6 @@ #!bash printf -- '%s\0' jsonschema/pipes/*.schema.json | xargs -0 \ - quicktype -s schema -l rust --visibility public --derive-debug --derive-partial-eq -o src/types.rs + quicktype -s schema -l rust \ + --visibility public --derive-debug --derive-clone --derive-partial-eq -o \ + src/types.rs diff --git a/src/lib.rs b/src/lib.rs index 7a22913..b660433 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,28 @@ mod context_loader; mod params_loader; mod types; +mod types_ext; +mod writer; use std::collections::HashMap; -use std::fs::OpenOptions; -use std::io::Write; -use context_loader::PayloadErrorKind; -use params_loader::ParamsError; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use serde_json::json; +use serde_json::Map; use serde_json::Value; use thiserror::Error; use crate::context_loader::DefaultLoader as PipesDefaultContextLoader; pub use crate::context_loader::LoadContext; +use crate::context_loader::PayloadErrorKind; use crate::params_loader::EnvVarLoader as PipesEnvVarParamsLoader; pub use crate::params_loader::LoadParams; +use crate::params_loader::ParamsError; pub use crate::types::{Method, PipesContextData, PipesMessage}; +use crate::writer::message_writer::get_opened_payload; +use crate::writer::message_writer::DefaultWriter as PipesDefaultMessageWriter; +pub use crate::writer::message_writer::MessageWriter; +pub use crate::writer::message_writer_channel::MessageWriterChannel; #[derive(Serialize)] #[serde(rename_all = "UPPERCASE")] @@ -29,12 +34,38 @@ pub enum AssetCheckSeverity { // partial translation of // https://github.com/dagster-io/dagster/blob/258d9ca0db/python_modules/dagster-pipes/dagster_pipes/__init__.py#L859-L871 #[derive(Debug)] -pub struct PipesContext { +pub struct PipesContext +where + W: MessageWriter, +{ data: PipesContextData, - writer: PipesFileMessageWriter, + message_channel: W::Channel, } -impl PipesContext { +impl PipesContext +where + W: MessageWriter, +{ + pub fn new( + context_data: PipesContextData, + message_params: Map, + message_writer: &W, + ) -> Self { + let mut message_channel = message_writer.open(message_params); + let opened_payload = get_opened_payload(message_writer); + let opened_message = PipesMessage { + dagster_pipes_version: "0.1".to_string(), // TODO: Convert to `const` + method: Method::Opened, + params: Some(opened_payload), + }; + message_channel.write_message(opened_message); + + Self { + data: context_data, + message_channel, + } + } + pub fn report_asset_materialization(&mut self, asset_key: &str, metadata: serde_json::Value) { let params: HashMap> = HashMap::from([ ("asset_key".to_string(), Some(json!(asset_key))), @@ -42,12 +73,8 @@ impl PipesContext { ("data_version".to_string(), None), // TODO - support data versions ]); - let msg = PipesMessage { - dagster_pipes_version: "0.1".to_string(), - method: Method::ReportAssetMaterialization, - params: Some(params), - }; - self.writer.write_message(&msg); + let msg = PipesMessage::new(Method::ReportAssetMaterialization, Some(params)); + self.message_channel.write_message(msg); } pub fn report_asset_check( @@ -66,36 +93,11 @@ impl PipesContext { ("metadata".to_string(), Some(metadata)), ]); - let msg = PipesMessage { - dagster_pipes_version: "0.1".to_string(), - method: Method::ReportAssetCheck, - params: Some(params), - }; - self.writer.write_message(&msg); - } -} - -#[derive(Debug)] -struct PipesFileMessageWriter { - path: String, -} -impl PipesFileMessageWriter { - fn write_message(&mut self, message: &PipesMessage) { - let serialized_msg = serde_json::to_string(&message).unwrap(); - let mut file = OpenOptions::new().append(true).open(&self.path).unwrap(); - writeln!(file, "{serialized_msg}").unwrap(); - - // TODO - optional `stderr` based writing - //eprintln!("{}", serialized_msg); + let msg = PipesMessage::new(Method::ReportAssetCheck, Some(params)); + self.message_channel.write_message(msg); } } -#[derive(Debug, Deserialize)] -struct PipesMessagesParams { - path: Option, // write to file - stdio: Option, // stderr | stdout (unsupported) -} - #[derive(Debug, Error)] #[non_exhaustive] pub enum DagsterPipesError { @@ -111,26 +113,19 @@ pub enum DagsterPipesError { // partial translation of // https://github.com/dagster-io/dagster/blob/258d9ca0db/python_modules/dagster-pipes/dagster_pipes/__init__.py#L798-L838 #[must_use] -pub fn open_dagster_pipes() -> Result { +pub fn open_dagster_pipes() -> Result, DagsterPipesError> { let params_loader = PipesEnvVarParamsLoader::new(); let context_loader = PipesDefaultContextLoader::new(); + let message_writer = PipesDefaultMessageWriter::new(); let context_params = params_loader.load_context_params()?; + let message_params = params_loader.load_message_params()?; + let context_data = context_loader.load_context(context_params)?; - let message_params = params_loader.load_message_params()?; - // TODO: Refactor into MessageWriter impl - let path = match &message_params["path"] { - Value::String(string) => string.clone(), - _ => panic!("Expected message \"path\" in bootstrap payload"), - }; - - //if stdio != "stderr" { - // panic!("only stderr supported for dagster pipes messages") - //} - - Ok(PipesContext { - data: context_data, - writer: PipesFileMessageWriter { path }, - }) + Ok(PipesContext::new( + context_data, + message_params, + &message_writer, + )) } diff --git a/src/params_loader.rs b/src/params_loader.rs index d6d48ac..9c5c86a 100644 --- a/src/params_loader.rs +++ b/src/params_loader.rs @@ -11,7 +11,7 @@ const DAGSTER_PIPES_MESSAGES_ENV_VAR: &str = "DAGSTER_PIPES_MESSAGES"; /// Load params passed from the orchestration process by the context injector and /// message reader. These params are used to respectively bootstrap the -/// [`PipesContextLoader`] and [`PipesMessageWriter`]. +/// [`PipesContextLoader`] and [`MessageWriter`](crate::MessageWriter). pub trait LoadParams { /// Whether or not this process has been provided with provided with information /// to create a `PipesContext` or should instead return a mock. diff --git a/src/types.rs b/src/types.rs index a4228a3..1103b58 100644 --- a/src/types.rs +++ b/src/types.rs @@ -16,7 +16,7 @@ use std::collections::HashMap; /// The serializable data passed from the orchestration process to the external process. This /// gets wrapped in a PipesContext. -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PipesContextData { pub asset_keys: Option>, @@ -39,21 +39,21 @@ pub struct PipesContextData { pub run_id: String, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PartitionKeyRange { pub end: Option, pub start: Option, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PartitionTimeWindow { pub end: Option, pub start: Option, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ProvenanceByAssetKey { pub code_version: Option, @@ -62,7 +62,7 @@ pub struct ProvenanceByAssetKey { pub is_user_provided: Option, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PipesException { /// exception that explicitly led to this exception pub cause: Box>, @@ -79,7 +79,7 @@ pub struct PipesException { } /// exception that being handled when this exception was raised -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ContextClass { /// exception that explicitly led to this exception pub cause: Box>, @@ -95,7 +95,7 @@ pub struct ContextClass { pub stack: Option>, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PipesExceptionClass { /// exception that explicitly led to this exception pub cause: Box>, @@ -111,7 +111,7 @@ pub struct PipesExceptionClass { pub stack: Option>, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PipesMessage { /// The version of the Dagster Pipes protocol #[serde(rename = "__dagster_pipes_version")] @@ -125,7 +125,7 @@ pub struct PipesMessage { } /// Event type -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Method { Closed, @@ -144,7 +144,7 @@ pub enum Method { ReportCustomMessage, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PipesMetadataValue { pub raw_value: Option, @@ -152,7 +152,7 @@ pub struct PipesMetadataValue { pub pipes_metadata_value_type: Option, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Type { Asset, @@ -188,7 +188,7 @@ pub enum Type { Url, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum RawValue { AnythingArray(Vec>), diff --git a/src/types_ext.rs b/src/types_ext.rs new file mode 100644 index 0000000..f047d31 --- /dev/null +++ b/src/types_ext.rs @@ -0,0 +1,13 @@ +use std::collections::HashMap; + +use crate::{Method, PipesMessage}; + +impl PipesMessage { + pub fn new(method: Method, params: Option>>) -> Self { + Self { + dagster_pipes_version: "0.1".to_string(), // TODO: Make `const` + method, + params, + } + } +} diff --git a/src/writer.rs b/src/writer.rs new file mode 100644 index 0000000..bbd90bf --- /dev/null +++ b/src/writer.rs @@ -0,0 +1,13 @@ +pub mod message_writer; +pub mod message_writer_channel; + +use serde::Serialize; + +// TODO: Might need to rename since `In` is not supported +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum StdStream { + #[serde(rename = "stdout")] + Out, + #[serde(rename = "stderr")] + Err, +} diff --git a/src/writer/message_writer.rs b/src/writer/message_writer.rs new file mode 100644 index 0000000..4ccadf7 --- /dev/null +++ b/src/writer/message_writer.rs @@ -0,0 +1,208 @@ +use std::collections::HashMap; + +use serde_json::{Map, Value}; + +use crate::writer::message_writer_channel::{ + BufferedStreamChannel, DefaultChannel, FileChannel, MessageWriterChannel, StreamChannel, +}; +use crate::writer::StdStream; + +mod private { + pub struct Token; // To seal certain trait methods +} + +/// Write messages back to Dagster, via its associated [`Self::Channel`]. +pub trait MessageWriter { + type Channel: MessageWriterChannel; + + /// Initialize a channel for writing messages back to Dagster. + /// + /// This method should takes the params passed by the orchestration-side + /// `PipesMessageReader` and use them to construct and yield + /// [`MessageWriterChannel`]. + fn open(&self, params: Map) -> Self::Channel; + + /// Return a payload containing information about the external process to be passed back to + /// the orchestration process. This should contain information that cannot be known before + /// the external process is launched. + /// + /// # Note + /// This method is sealed — it should not be overridden by users. + /// Instead, users should override [`Self::get_opened_extras`] to inject custom data. + /// + /// ```compile_fail + /// # use serde_json::{Map, Value}; + /// # use dagster_pipes_rust::{MessageWriter, MessageWriterChannel, PipesMessage}; + /// # + /// struct MyMessageWriter(u64); + /// # + /// # struct MyChannel; + /// # impl MessageWriterChannel for MyChannel { + /// # fn write_message(&mut self, message: PipesMessage) { + /// # todo!() + /// # } + /// # }; + /// + /// impl MessageWriter for MyMessageWriter { + /// # type Channel = MyChannel; + /// # + /// # fn open(&self, params: Map) -> Self::Channel { + /// // ... + /// # todo!() + /// # } + /// } + /// + /// MyMessageWriter(42).get_opened_payload(private::Token); // use of undeclared crate or module `private` + /// ``` + fn get_opened_payload(&self, _: private::Token) -> HashMap> { + let mut extras = HashMap::new(); + extras.insert( + "extras".to_string(), + Some(Value::Object(self.get_opened_extras())), + ); + extras + } + + /// Return arbitary reader-specific information to be passed back to the orchestration + /// process. The information will be returned under the `extras` key of the initialization payload. + fn get_opened_extras(&self) -> Map { + Map::new() + } +} + +/// Public accessor to the sealed method +pub fn get_opened_payload(writer: &impl MessageWriter) -> HashMap> { + writer.get_opened_payload(private::Token) +} + +pub struct DefaultWriter; + +impl DefaultWriter { + pub fn new() -> Self { + Self + } +} + +impl Default for DefaultWriter { + fn default() -> Self { + Self::new() + } +} + +impl MessageWriter for DefaultWriter { + type Channel = DefaultChannel; + + fn open(&self, params: Map) -> Self::Channel { + const FILE_PATH_KEY: &str = "path"; + const STDIO_KEY: &str = "stdio"; + const BUFFERED_STDIO_KEY: &str = "buffered_stdio"; + const STDERR: &str = "stderr"; + const STDOUT: &str = "stdout"; + const INCLUDE_STDIO_IN_MESSAGES_KEY: &str = "include_stdio_in_messages"; + + match ( + params.get(FILE_PATH_KEY), + params.get(STDIO_KEY), + params.get(BUFFERED_STDIO_KEY), + ) { + (Some(Value::String(path)), _, _) => { + // TODO: This is a simplified implementation. Utilize `PipesLogWriter` + DefaultChannel::File(FileChannel::new(path.into())) + } + (None, Some(Value::String(stream)), _) => match &*(stream.to_lowercase()) { + STDOUT => DefaultChannel::Stream(StreamChannel::new(StdStream::Out)), + STDERR => DefaultChannel::Stream(StreamChannel::new(StdStream::Err)), + _ => panic!("Invalid stream provided for stdio writer channel"), + }, + (None, None, Some(Value::String(stream))) => { + // Once `PipesBufferedStreamMessageWriterChannel` is dropped, the buffered data is written + match &*(stream.to_lowercase()) { + STDOUT => { + DefaultChannel::BufferedStream(BufferedStreamChannel::new(StdStream::Out)) + } + STDERR => { + DefaultChannel::BufferedStream(BufferedStreamChannel::new(StdStream::Err)) + } + _ => panic!("Invalid stream provided for buffered stdio writer channel"), + } + } + _ => panic!("No way to write messages"), + } + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use serde_json::json; + + use super::*; + use crate::writer::message_writer_channel::{ + BufferedStreamChannel, FileChannel, StreamChannel, + }; + + #[test] + fn test_open_with_file_path_key() { + let writer = DefaultWriter; + let params = serde_json::from_str(r#"{"path": "my-file-path"}"#) + .expect("Failed to parse raw JSON string"); + assert_eq!( + writer.open(params), + DefaultChannel::File(FileChannel::new("my-file-path".into())) + ); + } + + #[rstest] + #[case("stdout", StdStream::Out)] + #[case("stderr", StdStream::Err)] + fn test_open_with_stdio_key(#[case] value: &str, #[case] stream: StdStream) { + let writer = DefaultWriter; + let Value::Object(params) = json!({"stdio": value}) else { + panic!("Unexpected JSON type encountered") + }; + assert_eq!( + writer.open(params), + DefaultChannel::Stream(StreamChannel::new(stream)) + ); + } + + #[rstest] + #[case("stdout", StdStream::Out)] + #[case("stderr", StdStream::Err)] + fn test_open_with_buffered_stdio_key(#[case] value: &str, #[case] stream: StdStream) { + let writer = DefaultWriter; + let Value::Object(params) = json!({"buffered_stdio": value}) else { + panic!("Unexpected JSON type encountered") + }; + assert_eq!( + writer.open(params), + DefaultChannel::BufferedStream(BufferedStreamChannel::new(stream)) + ); + } + + #[test] + fn test_open_prioritizes_file_path_over_everything_else() { + let writer = DefaultWriter; + let Value::Object(params) = + json!({"path": "my-file-path", "stdio": "stdout", "buffered_stdio": "stderr"}) + else { + panic!("Unexpected JSON type encountered") + }; + assert_eq!( + writer.open(params), + DefaultChannel::File(FileChannel::new("my-file-path".into())) + ); + } + + #[test] + fn test_open_prioritizes_stream_over_buffered_stream() { + let writer = DefaultWriter; + let Value::Object(params) = json!({"stdio": "stdout", "buffered_stdio": "stderr"}) else { + panic!("Unexpected JSON type encountered") + }; + assert_eq!( + writer.open(params), + DefaultChannel::Stream(StreamChannel::new(StdStream::Out)) + ); + } +} diff --git a/src/writer/message_writer_channel.rs b/src/writer/message_writer_channel.rs new file mode 100644 index 0000000..0f3f3bb --- /dev/null +++ b/src/writer/message_writer_channel.rs @@ -0,0 +1,152 @@ +use std::{ffi::OsString, fs::OpenOptions, io::Write}; + +use crate::types::PipesMessage; + +use super::StdStream; + +/// Write messages back to the Dagster orchestration process. +/// To be used in conjunction with [`MessageWriter`](crate::MessageWriter). +pub trait MessageWriterChannel { + /// Write a message to the orchestration process + fn write_message(&mut self, message: PipesMessage); +} + +#[derive(Debug, PartialEq)] +pub struct FileChannel { + path: OsString, +} + +impl FileChannel { + pub fn new(path: OsString) -> Self { + Self { path } + } +} + +impl MessageWriterChannel for FileChannel { + fn write_message(&mut self, message: PipesMessage) { + let mut file = OpenOptions::new().append(true).open(&self.path).unwrap(); + let json = serde_json::to_string(&message).unwrap(); + writeln!(file, "{json}").unwrap(); + } +} + +#[derive(Debug, PartialEq)] +pub struct StreamChannel { + stream: StdStream, +} + +impl StreamChannel { + pub fn new(stream: StdStream) -> Self { + Self { stream } + } + + fn _format_message(message: &PipesMessage) -> Vec { + format!("{}\n", serde_json::to_string(message).unwrap()).into_bytes() + } +} + +impl MessageWriterChannel for StreamChannel { + fn write_message(&mut self, message: PipesMessage) { + match self.stream { + StdStream::Out => std::io::stdout() + .write_all(&Self::_format_message(&message)) + .unwrap(), + StdStream::Err => std::io::stderr() + .write_all(&Self::_format_message(&message)) + .unwrap(), + } + } +} + +#[derive(Debug, PartialEq)] +pub struct BufferedStreamChannel { + buffer: Vec, + stream: StdStream, +} + +impl BufferedStreamChannel { + pub fn new(stream: StdStream) -> Self { + Self { + buffer: vec![], + stream, + } + } + + /// Flush messages in the buffer to the stream + ///
This class will called once on `Drop`
+ fn flush(&mut self) { + let _: Vec<_> = self + .buffer + .iter() + .map(|msg| match self.stream { + StdStream::Out => std::io::stdout() + .write(&Self::_format_message(msg)) + .unwrap(), + StdStream::Err => std::io::stderr() + .write(&Self::_format_message(msg)) + .unwrap(), + }) + .collect(); + self.buffer.clear(); + } + + fn _format_message(message: &PipesMessage) -> Vec { + format!("{}\n", serde_json::to_string(message).unwrap()).into_bytes() + } +} + +impl Drop for BufferedStreamChannel { + /// Flush the data when out of scope or panicked. + ///
Panic aborting will prevent `Drop` and this function from running
+ fn drop(&mut self) { + self.flush(); + } +} + +impl MessageWriterChannel for BufferedStreamChannel { + fn write_message(&mut self, message: PipesMessage) { + self.buffer.push(message); + } +} + +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub enum DefaultChannel { + File(FileChannel), + Stream(StreamChannel), + BufferedStream(BufferedStreamChannel), +} + +impl MessageWriterChannel for DefaultChannel { + fn write_message(&mut self, message: PipesMessage) { + match self { + Self::File(channel) => channel.write_message(message), + Self::Stream(channel) => channel.write_message(message), + Self::BufferedStream(channel) => channel.write_message(message), + } + } +} + +#[cfg(test)] +mod tests_file_channel { + use tempfile::NamedTempFile; + + use crate::{Method, PipesMessage}; + + use super::{FileChannel, MessageWriterChannel}; + + #[test] + fn test_write_message() { + let file = NamedTempFile::new().expect("Failed to create tempfile for testing"); + let mut channel = FileChannel::new(file.path().into()); + let message = PipesMessage::new(Method::Opened, None); + channel.write_message(message.clone()); + + let file_content = + std::fs::read_to_string(file.path()).expect("Failed to read from tempfile"); + assert_eq!( + message, + serde_json::from_str(&file_content).expect("Failed to serialize PipesMessage") + ); + } +}