From ea8e26e8596dd2267c537c7ad74395f1eda63a4c Mon Sep 17 00:00:00 2001
From: gracig <15052330+gracig@users.noreply.github.com>
Date: Mon, 12 Aug 2024 13:46:08 -0300
Subject: [PATCH] Create a prelude for the fluent_sdk

---
 crates/fluent-sdk/src/lib.rs         | 154 +++------------------------
 crates/fluent-sdk/src/openai-chat.rs |   0
 crates/fluent-sdk/src/openai.rs      | 141 ++++++++++++++++++++++++
 3 files changed, 153 insertions(+), 142 deletions(-)
 delete mode 100644 crates/fluent-sdk/src/openai-chat.rs
 create mode 100644 crates/fluent-sdk/src/openai.rs

diff --git a/crates/fluent-sdk/src/lib.rs b/crates/fluent-sdk/src/lib.rs
index 13dec95..aa84fce 100644
--- a/crates/fluent-sdk/src/lib.rs
+++ b/crates/fluent-sdk/src/lib.rs
@@ -1,9 +1,19 @@
 use anyhow::anyhow;
 use serde::{Deserialize, Serialize};
-use serde_json::{json, Value};
+use serde_json::Value;
 use std::collections::HashMap;
 use strum::{Display, EnumString};
+pub mod openai;
+pub mod prelude {
+    pub use crate::openai::*;
+    pub use crate::{FluentRequest, FluentSdkRequest, KeyValue};
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct Response {
+    pub data: fluent_core::types::Response,
+}
 
 #[async_trait::async_trait]
 pub trait FluentSdkRequest: Into<FluentRequest> + Clone {
     fn as_request(&self) -> FluentRequest {
@@ -30,7 +40,7 @@ pub struct FluentRequest {
 impl FluentRequest {
     pub async fn run(&self) -> anyhow::Result<Response> {
         // Convert the implementing type into a FluentRequest
-        let request: FluentRequest = self.clone();
+        let request = self.clone();
         // Perform the run logic that was previously in the `run` function
         let engine_name = request
             .engine
@@ -230,143 +240,3 @@ pub struct OverrideValue {
     pub key: String,
     pub value: Value,
 }
-
-#[derive(Debug, Deserialize, Serialize)]
-pub struct Response {
-    pub data: fluent_core::types::Response,
-}
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-pub struct FluentOpenAIChatRequest {
-    pub prompt: String,
-    pub openai_key: String,
-    pub model: Option<String>,
-    pub response_format: Option<Value>,
-    pub temperature: Option<f64>,
-    pub max_tokens: Option<i64>,
-    pub top_p: Option<f64>,
-    pub n: Option<i8>,
-    pub stop: Option<Vec<String>>,
-    pub frequency_penalty: Option<f64>,
-    pub presence_penalty: Option<f64>,
-}
-impl From<FluentOpenAIChatRequest> for FluentRequest {
-    fn from(request: FluentOpenAIChatRequest) -> Self {
-        let mut overrides = vec![];
-        if let Some(response_format) = request.response_format {
-            overrides.push(("response_format".to_string(), response_format));
-        }
-        if let Some(temperature) = request.temperature {
-            overrides.push(("temperature".to_string(), json!(temperature)));
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            overrides.push(("max_tokens".to_string(), json!(max_tokens)));
-        }
-        if let Some(top_p) = request.top_p {
-            overrides.push(("top_p".to_string(), json!(top_p)));
-        }
-        if let Some(frequency_penalty) = request.frequency_penalty {
-            overrides.push(("frequency_penalty".to_string(), json!(frequency_penalty)));
-        }
-        if let Some(presence_penalty) = request.presence_penalty {
-            overrides.push(("presence_penalty".to_string(), json!(presence_penalty)));
-        }
-        if let Some(model_name) = request.model {
-            overrides.push(("modelName".to_string(), json!(model_name)));
-        }
-        if let Some(n) = request.n {
-            overrides.push(("n".to_string(), json!(n)));
-        }
-        if let Some(stop) = request.stop {
-            overrides.push(("stop".to_string(), json!(stop)));
-        }
-        FluentRequest {
-            request: Some(request.prompt),
-            engine: Some(EngineTemplate::OpenAIChatCompletions),
-            credentials: Some(vec![KeyValue::new("OPENAI_API_KEY", &request.openai_key)]),
-            overrides: Some(overrides.into_iter().collect()),
-            parse_code: None,
-        }
-    }
-}
-impl FluentSdkRequest for FluentOpenAIChatRequest {}
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-pub struct FluentOpenAIChatRequestBuilder {
-    request: FluentOpenAIChatRequest,
-}
-impl Default for FluentOpenAIChatRequestBuilder {
-    fn default() -> Self {
-        Self {
-            request: FluentOpenAIChatRequest {
-                prompt: String::new(),
-                openai_key: String::new(),
-                response_format: None,
-                temperature: None,
-                max_tokens: None,
-                top_p: None,
-                frequency_penalty: None,
-                presence_penalty: None,
-                model: None,
-                n: None,
-                stop: None,
-            },
-        }
-    }
-}
-
-impl FluentOpenAIChatRequestBuilder {
-    pub fn prompt(mut self, prompt: String) -> Self {
-        self.request.prompt = prompt;
-        self
-    }
-    pub fn openai_key(mut self, openai_key: String) -> Self {
-        self.request.openai_key = openai_key;
-        self
-    }
-    pub fn response_format(mut self, response_format: Value) -> Self {
-        self.request.response_format = Some(response_format);
-        self
-    }
-    pub fn temperature(mut self, temperature: f64) -> Self {
-        self.request.temperature = Some(temperature);
-        self
-    }
-    pub fn max_tokens(mut self, max_tokens: i64) -> Self {
-        self.request.max_tokens = Some(max_tokens);
-        self
-    }
-    pub fn top_p(mut self, top_p: f64) -> Self {
-        self.request.top_p = Some(top_p);
-        self
-    }
-    pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self {
-        self.request.frequency_penalty = Some(frequency_penalty);
-        self
-    }
-    pub fn presence_penalty(mut self, presence_penalty: f64) -> Self {
-        self.request.presence_penalty = Some(presence_penalty);
-        self
-    }
-    pub fn model(mut self, model: String) -> Self {
-        self.request.model = Some(model);
-        self
-    }
-    pub fn n(mut self, n: i8) -> Self {
-        self.request.n = Some(n);
-        self
-    }
-    pub fn stop(mut self, stop: Vec<String>) -> Self {
-        self.request.stop = Some(stop);
-        self
-    }
-    pub fn build(self) -> anyhow::Result<FluentOpenAIChatRequest> {
-        if self.request.prompt.is_empty() {
-            return Err(anyhow!("Prompt is required"));
-        }
-        if self.request.openai_key.is_empty() {
-            return Err(anyhow!("OpenAI key is required"));
-        }
-        Ok(self.request)
-    }
-}
diff --git a/crates/fluent-sdk/src/openai-chat.rs b/crates/fluent-sdk/src/openai-chat.rs
deleted file mode 100644
index e69de29..0000000
diff --git a/crates/fluent-sdk/src/openai.rs b/crates/fluent-sdk/src/openai.rs
new file mode 100644
index 0000000..55fb9c3
--- /dev/null
+++ b/crates/fluent-sdk/src/openai.rs
@@ -0,0 +1,141 @@
+use anyhow::anyhow;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+
+use crate::{EngineTemplate, FluentRequest, FluentSdkRequest, KeyValue};
+
+impl FluentSdkRequest for FluentOpenAIChatRequest {}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct FluentOpenAIChatRequest {
+    pub prompt: String,
+    pub openai_key: String,
+    pub model: Option<String>,
+    pub response_format: Option<Value>,
+    pub temperature: Option<f64>,
+    pub max_tokens: Option<i64>,
+    pub top_p: Option<f64>,
+    pub n: Option<i8>,
+    pub stop: Option<Vec<String>>,
+    pub frequency_penalty: Option<f64>,
+    pub presence_penalty: Option<f64>,
+}
+impl From<FluentOpenAIChatRequest> for FluentRequest {
+    fn from(request: FluentOpenAIChatRequest) -> Self {
+        let mut overrides = vec![];
+        if let Some(response_format) = request.response_format {
+            overrides.push(("response_format".to_string(), response_format));
+        }
+        if let Some(temperature) = request.temperature {
+            overrides.push(("temperature".to_string(), json!(temperature)));
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            overrides.push(("max_tokens".to_string(), json!(max_tokens)));
+        }
+        if let Some(top_p) = request.top_p {
+            overrides.push(("top_p".to_string(), json!(top_p)));
+        }
+        if let Some(frequency_penalty) = request.frequency_penalty {
+            overrides.push(("frequency_penalty".to_string(), json!(frequency_penalty)));
+        }
+        if let Some(presence_penalty) = request.presence_penalty {
+            overrides.push(("presence_penalty".to_string(), json!(presence_penalty)));
+        }
+        if let Some(model_name) = request.model {
+            overrides.push(("modelName".to_string(), json!(model_name)));
+        }
+        if let Some(n) = request.n {
+            overrides.push(("n".to_string(), json!(n)));
+        }
+        if let Some(stop) = request.stop {
+            overrides.push(("stop".to_string(), json!(stop)));
+        }
+        FluentRequest {
+            request: Some(request.prompt),
+            engine: Some(EngineTemplate::OpenAIChatCompletions),
+            credentials: Some(vec![KeyValue::new("OPENAI_API_KEY", &request.openai_key)]),
+            overrides: Some(overrides.into_iter().collect()),
+            parse_code: None,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct FluentOpenAIChatRequestBuilder {
+    request: FluentOpenAIChatRequest,
+}
+impl Default for FluentOpenAIChatRequestBuilder {
+    fn default() -> Self {
+        Self {
+            request: FluentOpenAIChatRequest {
+                prompt: String::new(),
+                openai_key: String::new(),
+                response_format: None,
+                temperature: None,
+                max_tokens: None,
+                top_p: None,
+                frequency_penalty: None,
+                presence_penalty: None,
+                model: None,
+                n: None,
+                stop: None,
+            },
+        }
+    }
+}
+
+impl FluentOpenAIChatRequestBuilder {
+    pub fn prompt(mut self, prompt: String) -> Self {
+        self.request.prompt = prompt;
+        self
+    }
+    pub fn openai_key(mut self, openai_key: String) -> Self {
+        self.request.openai_key = openai_key;
+        self
+    }
+    pub fn response_format(mut self, response_format: Value) -> Self {
+        self.request.response_format = Some(response_format);
+        self
+    }
+    pub fn temperature(mut self, temperature: f64) -> Self {
+        self.request.temperature = Some(temperature);
+        self
+    }
+    pub fn max_tokens(mut self, max_tokens: i64) -> Self {
+        self.request.max_tokens = Some(max_tokens);
+        self
+    }
+    pub fn top_p(mut self, top_p: f64) -> Self {
+        self.request.top_p = Some(top_p);
+        self
+    }
+    pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self {
+        self.request.frequency_penalty = Some(frequency_penalty);
+        self
+    }
+    pub fn presence_penalty(mut self, presence_penalty: f64) -> Self {
+        self.request.presence_penalty = Some(presence_penalty);
+        self
+    }
+    pub fn model(mut self, model: String) -> Self {
+        self.request.model = Some(model);
+        self
+    }
+    pub fn n(mut self, n: i8) -> Self {
+        self.request.n = Some(n);
+        self
+    }
+    pub fn stop(mut self, stop: Vec<String>) -> Self {
+        self.request.stop = Some(stop);
+        self
+    }
+    pub fn build(self) -> anyhow::Result<FluentOpenAIChatRequest> {
+        if self.request.prompt.is_empty() {
+            return Err(anyhow!("Prompt is required"));
+        }
+        if self.request.openai_key.is_empty() {
+            return Err(anyhow!("OpenAI key is required"));
+        }
+        Ok(self.request)
+    }
+}
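
Usage sketch (not applied by this patch): with the prelude in place, a
downstream crate can reach the builder, request, and trait types through a
single import. The crate name `fluent_sdk`, the `gpt-4o` model string, the
tokio runtime, and reading the key from the OPENAI_API_KEY environment
variable are illustrative assumptions, not anything the patch prescribes.

    use fluent_sdk::prelude::*;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // build() enforces that prompt and openai_key are non-empty.
        let request = FluentOpenAIChatRequestBuilder::default()
            .prompt("Say hello in French".to_string())
            .openai_key(std::env::var("OPENAI_API_KEY")?)
            .model("gpt-4o".to_string()) // assumed model name
            .temperature(0.2)
            .build()?;

        // as_request() converts via From<FluentOpenAIChatRequest> into a
        // generic FluentRequest; run() executes it against the configured
        // OpenAIChatCompletions engine.
        let response = request.as_request().run().await?;
        println!("{:?}", response.data);
        Ok(())
    }

The optional setters map one-to-one onto engine overrides (for example,
`.model(...)` becomes the `modelName` override), so anything left unset is
simply omitted from the generated FluentRequest.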