From 96fde1b4b3477c5e3c590effe8d10f7fa0e5a023 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Mon, 28 Oct 2024 12:32:52 -0400 Subject: [PATCH 01/13] Added Bedrock UI elements, credential storage and a few other things --- Cargo.toml | 7 +- crates/assistant/src/assistant_settings.rs | 1 + crates/bedrock/Cargo.toml | 33 + crates/bedrock/src/bedrock.rs | 334 +++++++++ crates/language_model/Cargo.toml | 4 + crates/language_model/src/provider.rs | 1 + crates/language_model/src/provider/bedrock.rs | 645 ++++++++++++++++++ crates/language_model/src/registry.rs | 6 + crates/language_model/src/request.rs | 9 + crates/language_model/src/settings.rs | 11 + 10 files changed, 1049 insertions(+), 2 deletions(-) create mode 100644 crates/bedrock/Cargo.toml create mode 100644 crates/bedrock/src/bedrock.rs create mode 100644 crates/language_model/src/provider/bedrock.rs diff --git a/Cargo.toml b/Cargo.toml index c6d182fd5808a..5b6e6688dab88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/assistant_tool", "crates/audio", "crates/auto_update", + "crates/bedrock", "crates/breadcrumbs", "crates/call", "crates/channel", @@ -166,8 +167,6 @@ members = [ # # Tooling # - - "tooling/xtask", ] default-members = ["crates/zed"] @@ -186,6 +185,7 @@ assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_tool = { path = "crates/assistant_tool" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } +bedrock = { path = "crates/bedrock" } breadcrumbs = { path = "crates/breadcrumbs" } call = { path = "crates/call" } channel = { path = "crates/channel" } @@ -326,6 +326,9 @@ async-trait = "0.1" async-tungstenite = "0.28" async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } +aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] } +aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } +aws-sdk-bedrockruntime = { version = "1.57.0", features = ["behavior-version-latest"]} base64 = "0.22" bitflags = "2.6.0" blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" } diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 6bc0c21ea5d11..5a2d234f6f7ae 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -430,6 +430,7 @@ pub struct LanguageModelSelection { fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { schemars::schema::SchemaObject { enum_values: Some(vec![ + "bedrock".into(), "anthropic".into(), "google".into(), "ollama".into(), diff --git a/crates/bedrock/Cargo.toml b/crates/bedrock/Cargo.toml new file mode 100644 index 0000000000000..ae13bf09a090b --- /dev/null +++ b/crates/bedrock/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "bedrock" +version = "0.1.0" +edition = "2021" +publish = false +license = "AGPL-3.0-or-later" + +[features] +default = [] +schemars = ["dep:schemars"] + +[lints] +workspace = true + +[lib] +path = "src/bedrock.rs" + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +strum.workspace = true +thiserror.workspace = true +util.workspace = true +aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"]} +aws-config = {workspace = true, features = ["behavior-version-latest"]} + +[dev-dependencies] +tokio.workspace = true diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs new file mode 100644 index 0000000000000..baba056760aae --- /dev/null +++ b/crates/bedrock/src/bedrock.rs @@ -0,0 +1,334 @@ +use std::time::Duration; +use std::{pin::Pin, str::FromStr}; + +use anyhow::{anyhow, Context, Result}; +use aws_sdk_bedrockruntime as bedrock; +use chrono::{DateTime, Utc}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use serde::{Deserialize, Serialize}; +use strum::{EnumIter, EnumString}; +use thiserror::Error; +use util::ResultExt as _; + + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct BedrockModelCacheConfiguration { + pub min_total_token: usize, + pub should_speculate: bool, + pub max_cache_anchors: usize, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[default] + #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")] + Claude3_5Sonnet, + #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] + Claude3Opus, + #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] + Claude3Sonnet, + #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")] + Claude3Haiku, + #[serde(rename = "custom")] + Custom { + name: String, + max_tokens: usize, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + /// Override this model with a different Bedrock model for tool calls. + tool_override: Option, + /// Indicates whether this custom model supports caching. + cache_configuration: Option, + max_output_tokens: Option, + default_temperature: Option, + }, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + if id.starts_with("claude-3-5-sonnet") { + Ok(Self::Claude3_5Sonnet) + } else if id.starts_with("claude-3-opus") { + Ok(Self::Claude3Opus) + } else if id.starts_with("claude-3-sonnet") { + Ok(Self::Claude3Sonnet) + } else if id.starts_with("claude-3-haiku") { + Ok(Self::Claude3Haiku) + } else { + Err(anyhow!("invalid model id")) + } + } + + pub fn id(&self) -> &str { + match self { + Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest", + Model::Claude3Opus => "claude-3-opus-latest", + Model::Claude3Sonnet => "claude-3-sonnet-latest", + Model::Claude3Haiku => "claude-3-haiku-latest", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn cache_configuration(&self) -> Option { + match self { + Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(BedrockModelCacheConfiguration { + min_total_token: 2_048, + should_speculate: true, + max_cache_anchors: 4, + }), + Self::Custom { + cache_configuration, + .. + } => cache_configuration.clone(), + _ => None, + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => 200_000, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> u32 { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096, + Self::Claude3_5Sonnet => 8_192, + Self::Custom { + max_output_tokens, .. + } => max_output_tokens.unwrap_or(4_096), + } + } + + pub fn default_temperature(&self) -> f32 { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => 1.0, + Self::Custom { + default_temperature, + .. + } => default_temperature.unwrap_or(1.0), + } + } + + pub fn tool_model_id(&self) -> &str { + if let Self::Custom { + tool_override: Some(tool_override), + .. + } = self + { + tool_override + } else { + self.id() + } + } +} + +pub async fn complete( + client: &bedrock::Client, + api_url: &str, + api_key: &str, + request: Request, +) -> Result { + todo!() +} + +pub async fn stream_completion( + client: &bedrock::Client, + api_url: &str, + api_key: &str, + request: Request, + low_speed_timeout: Option, +) -> Result>, BedrockError> { + todo!() +} + +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[serde(rename_all = "lowercase")] +pub enum CacheControlType { + Ephemeral, +} + +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +pub struct CacheControl { + #[serde(rename = "type")] + pub cache_type: CacheControlType, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum RequestContent { + #[serde(rename = "text")] + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "image")] + Image { + source: ImageSource, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ResponseContent { + #[serde(rename = "text")] + Text { text: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ImageSource { + #[serde(rename = "type")] + pub source_type: String, + pub media_type: String, + pub data: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub model: String, + pub max_tokens: u32, + pub messages: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub stop_sequences: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_p: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct StreamingRequest { + #[serde(flatten)] + pub base: Request, + pub stream: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Metadata { + pub user_id: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Usage { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_creation_input_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_read_input_tokens: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub id: String, + #[serde(rename = "type")] + pub response_type: String, + pub role: Role, + pub content: Vec, + pub model: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub stop_reason: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub stop_sequence: Option, + pub usage: Usage, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Event { + #[serde(rename = "message_start")] + MessageStart { message: Response }, + #[serde(rename = "content_block_start")] + ContentBlockStart { + index: usize, + content_block: ResponseContent, + }, + #[serde(rename = "content_block_delta")] + ContentBlockDelta { index: usize, delta: ContentDelta }, + #[serde(rename = "content_block_stop")] + ContentBlockStop { index: usize }, + #[serde(rename = "message_delta")] + MessageDelta { delta: MessageDelta, usage: Usage }, + #[serde(rename = "message_stop")] + MessageStop, + #[serde(rename = "ping")] + Ping, + #[serde(rename = "error")] + Error { error: ApiError }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageDelta { + pub stop_reason: Option, + pub stop_sequence: Option, +} + +#[derive(Error, Debug)] +pub enum BedrockError { + // TODO: propagate the error message + #[error("an error occurred while interacting with the Bedrock API")] + ApiError(bedrock::Error), + #[error("{0}")] + Other(#[from] anyhow::Error), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiError { + #[serde(rename = "type")] + pub error_type: String, + pub message: String, +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index e88675bbae3a9..d4ee3e018ccfb 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -21,7 +21,11 @@ test-support = [ ] [dependencies] +bedrock = { workspace = true, features = ["schemars"] } anthropic = { workspace = true, features = ["schemars"] } +aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } +aws-config = { workspace = true, features = ["behavior-version-latest"]} +aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] } anyhow.workspace = true client.workspace = true collections.workspace = true diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs index d2d162b75e056..bff273abc1faf 100644 --- a/crates/language_model/src/provider.rs +++ b/crates/language_model/src/provider.rs @@ -6,3 +6,4 @@ pub mod fake; pub mod google; pub mod ollama; pub mod open_ai; +pub mod bedrock; diff --git a/crates/language_model/src/provider/bedrock.rs b/crates/language_model/src/provider/bedrock.rs new file mode 100644 index 0000000000000..e0895a99068ef --- /dev/null +++ b/crates/language_model/src/provider/bedrock.rs @@ -0,0 +1,645 @@ +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, +}; +use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; +use anyhow::{anyhow, Context as _, Result}; +use aws_config as config; +use aws_config::meta::credentials::CredentialsProviderChain; +use aws_config::Region; +use aws_credential_types::Credentials; +use aws_sdk_bedrockruntime as bedrock_client; +use aws_sdk_bedrockruntime::Config; +// use bedrock::{BedrockError, ContentDelta, Event, ResponseContent}; +use collections::{BTreeMap, HashMap}; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::Stream; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _}; +use gpui::{AnyView, AppContext, AsyncAppContext, FontStyle, Model, ModelContext, Subscription, Task, TextStyle, View, WhiteSpace}; +use http_client::HttpClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr; +use std::{sync::Arc, time::Duration}; +use serde_json::Value; +use strum::IntoEnumIterator; +use bedrock::BedrockError; +use theme::ThemeSettings; +use ui::{prelude::*, Icon, IconName, Tooltip}; +use util::{maybe, ResultExt}; +use crate::provider::anthropic::{count_anthropic_tokens, map_to_language_model_completion_events}; + +const PROVIDER_ID : &str = "amazon-bedrock"; +const PROVIDER_NAME : &str = "Amazon Bedrock"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct AmazonBedrockSettings { + pub region: Option, + pub credentials: Option, + pub available_models: Vec +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub cache_configuration: Option, + pub max_output_tokens: Option, + pub default_temperature: Option, +} + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct AmazonBedrockCredentials { + pub access_key_id: String, + pub secret_access_key: String, + pub session_token: Option, +} + +// Different because we don't want to overwrite their AWS credentials +const ZED_BEDROCK_AAID: &str = "ZED_ACCESS_KEY_ID"; +const ZED_BEDROCK_SK: &str = "ZED_SECRET_ACCESS_KEY"; +const ZED_BEDROCK_REGION: &str = "ZED_AWS_REGION"; + +pub struct State { + credentials: Option, + credentials_from_env: bool, + region: Option, + _subscription: Subscription +} + + +pub struct BedrockLanguageModelProvider { + runtime_client: bedrock_client::Client, + state: gpui::Model +} + +impl State { + fn reset_credentials(&self, cx: &mut ModelContext) -> Task> { + let delete_aa_id= + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).bedrock.credentials.clone().unwrap().access_key_id); + let delete_sk: Task> = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).bedrock.credentials.clone().unwrap().secret_access_key); + cx.spawn(|this, mut cx| async move { + delete_aa_id.await.ok(); + delete_sk.await.ok(); + this.update(&mut cx, |this, cx| { + this.credentials = None; + this.credentials_from_env = false; + cx.notify(); + }) + }) + } + + fn set_credentials(&mut self, access_key_id: String, secret_key: String, region: String, cx: &mut ModelContext) -> Task> { + let write_aa_id = cx.write_credentials( + ZED_BEDROCK_AAID, // TODO: GET THIS REVIEWED, MAKE SURE IT DOESN'T BREAK STUFF LONG TERM + "Bearer", + access_key_id.as_bytes() + ); + let write_sk = cx.write_credentials( + ZED_BEDROCK_SK, // TODO: GET THIS REVIEWED, MAKE SURE IT DOESN'T BREAK STUFF LONG TERM + "Bearer", + secret_key.as_bytes() + ); + let write_region = cx.write_credentials( + ZED_BEDROCK_REGION, + "Bearer", + region.as_bytes() + ); + cx.spawn(|this, mut cx| async move { + write_aa_id.await?; + write_sk.await?; + write_region.await?; + + this.update(&mut cx, |this, cx| { + this.credentials = Some(AmazonBedrockCredentials{ + access_key_id, + secret_access_key: secret_key, + session_token: None + }); + this.region = Some(region); + cx.notify(); + }) + }) + } + + fn is_authenticated(&self) -> bool { + self.credentials.is_some() + } + + fn authenticate(&self, cx: &mut ModelContext) -> Task> { + // just hit the sdk-bedrock list models to check if the credentials are valid + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + cx.spawn(|this, mut cx| async move { + let (aa_id, sk, region, from_env) = if let (Ok(aa_id), Ok(sk), Ok(region)) + = (std::env::var(ZED_BEDROCK_AAID), std::env::var(ZED_BEDROCK_SK), std::env::var(ZED_BEDROCK_REGION)) + { + (aa_id, sk, region, true) + } else { + let (_, aa_id) = cx + .update(| cx | cx.read_credentials(ZED_BEDROCK_AAID))? + .await? + .ok_or_else(|| anyhow!("Access key ID not found"))?; + let (_, sk) = cx + .update(| cx | cx.read_credentials(ZED_BEDROCK_SK))? + .await? + .ok_or_else(|| anyhow!("Secret access key not found"))?; + let (_, region) = cx + .update(| cx | cx.read_credentials(ZED_BEDROCK_REGION))? + .await? + .ok_or_else(|| anyhow!("Region not found"))?; + + + (String::from_utf8(aa_id)?, String::from_utf8(sk)?, String::from_utf8(region)?, false) + }; + + this.update(&mut cx, |this, cx| { + this.credentials_from_env = from_env; + this.credentials = Some(AmazonBedrockCredentials{ + access_key_id: aa_id, + secret_access_key: sk, + session_token: None + }); + this.region = Some(region); + cx.notify(); + }) + }) + } + + } +} + +impl BedrockLanguageModelProvider { + pub fn new(cx: &mut AppContext) -> Self { + + let state = cx.new_model(|cx| State { + credentials: None, + region: Some(String::from("us-east-1")), + credentials_from_env: false, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }) + }); + + let region_def: String = state.read(cx).region.clone() + .or_else(|| {Some(String::from("us-east-1"))}) + .unwrap(); + let creds_clone = &state.read(cx).credentials.clone() + .or_else(|| { Some(AmazonBedrockCredentials::default()) }) + .unwrap(); + + let client_config = Config::builder() + .region(Region::new(region_def)) + .credentials_provider(Credentials::from_keys( + &creds_clone.clone().access_key_id, + &creds_clone.clone().secret_access_key, + creds_clone.clone().session_token, + )) + .build(); + + let runtime_client = bedrock_client::Client::from_conf(client_config); + + Self { + runtime_client, + state + } + } +} + +struct BedrockModel { + id: LanguageModelId, + model: bedrock::Model, + state: Model, +} + +impl BedrockModel { + fn stream_completion( + &self, + request: bedrock::Request, + cx: &AsyncAppContext + ) -> BoxFuture<'static, Result>>> { + todo!() + } +} + +impl LanguageModel for BedrockModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("anthropic/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens()) + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + count_anthropic_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_bedrock( + self.model.id().into(), + self.model.default_temperature(), + self.model.max_output_tokens(), + ); + + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let response = request.await.map_err(|err| anyhow!(err))?; + Ok(map_to_language_model_completion_events(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + + fn cache_configuration(&self) -> Option { + self.model + .cache_configuration() + .map(|config| LanguageModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + }) + } + + fn use_any_tool( + &self, + request: LanguageModelRequest, + tool_name: String, + tool_description: String, + input_schema: serde_json::Value, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let mut request = request.into_anthropic( + self.model.tool_model_id().into(), + self.model.default_temperature(), + self.model.max_output_tokens(), + ); + request.tool_choice = Some(anthropic::ToolChoice::Tool { + name: tool_name.clone(), + }); + request.tools = vec![anthropic::Tool { + name: tool_name.clone(), + description: tool_description, + input_schema, + }]; + + let response = self.stream_completion(request, cx); + self.request_limiter + .run(async move { + let response = response.await?; + Ok(anthropic::extract_tool_args_from_events( + tool_name, + Box::pin(response.map_err(|e| anyhow!(e))), + ) + .await? + .boxed()) + }) + .boxed() + } +} + +impl LanguageModelProvider for BedrockLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = BTreeMap::default(); + + for model in bedrock::Model::iter() { + if !matches!(model, bedrock::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in AllLanguageModelSettings::get_global(cx) + .bedrock + .available_models + .iter() + { + models.insert( + model.name.clone(), + bedrock::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + cache_configuration: model.cache_configuration.as_ref().map(|config| { + bedrock::BedrockModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + } + }), + max_output_tokens: model.max_output_tokens, + default_temperature: model.default_temperature, + }, + ); + } + + models + .into_values() + .map(|model| { + Arc::new(BedrockModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut AppContext) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut AppContext) -> Task> { + self.state.update(cx, |state, cx| state.reset_credentials(cx)) + } +} + +impl LanguageModelProviderState for BedrockLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +struct ConfigurationView { + access_key_id_editor: View, + secret_access_key_editor: View, + region_editor: View, + state: Model, + load_credentials_task: Option>, +} + +impl ConfigurationView { + const PLACEHOLDER_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXX"; + + fn new(state: gpui::Model, cx: &mut ViewContext) -> Self { + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn({ + let state = state.clone(); + |this, mut cx| async move { + if let Some(task) = state + .update(&mut cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(&mut cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + access_key_id_editor: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx); + editor + }), + secret_access_key_editor: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx); + editor + }), + region_editor: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx); + editor + }), + state, + load_credentials_task, + } + } + + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let access_key_id = self.access_key_id_editor + .read(cx) + .text(cx) + .to_string(); + let secret_access_key = self.secret_access_key_editor + .read(cx) + .text(cx) + .to_string(); + let region = self.region_editor + .read(cx) + .text(cx) + .to_string(); + + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.set_credentials(access_key_id, secret_access_key, region, cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn reset_credentials(&mut self, cx: &mut ViewContext) { + self.access_key_id_editor.update(cx, |editor, cx| editor.set_text("", cx)); + self.secret_access_key_editor.update(cx, |editor, cx| editor.set_text("", cx)); + self.region_editor.update(cx, |editor, cx| editor.set_text("", cx)); + + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.reset_credentials(cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn make_text_style(&self, cx: &ViewContext) -> TextStyle { + let settings = ThemeSettings::get_global(cx); + TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + truncate: None, + } + } + + fn render_aa_id_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let text_style = self.make_text_style(cx); + + EditorElement::new( + &self.access_key_id_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn render_sk_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let text_style = self.make_text_style(cx); + + EditorElement::new( + &self.secret_access_key_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn render_region_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let text_style = self.make_text_style(cx); + + EditorElement::new( + &self.region_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn should_render_editor(&self, cx: &mut ViewContext) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const IAM_CONSOLE_URL: &str = "https://us-east-1.console.aws.amazon.com/iam/home"; + const INSTRUCTIONS: [&str; 3] = [ + "To use Zed's assistant with Bedrock, you need to add the Access Key ID. and Secret Access Key. Follow these steps:", + "- Create a pair at:", + "- Paste your Access Key ID and Secret Key below and hit enter to use the assistant:", + ]; + let env_var_set = self.state.read(cx).credentials_from_env; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_credentials)) + .child(Label::new(INSTRUCTIONS[0])) + .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child( + Button::new("anthropic_console", IAM_CONSOLE_URL) + .style(ButtonStyle::Subtle) + .icon(IconName::ExternalLink) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, cx| cx.open_url(IAM_CONSOLE_URL)) + ) + ) + .child(Label::new(INSTRUCTIONS[2])) + .child( + h_flex() + .gap_1() + .child(self.render_aa_id_editor(cx)) + .child(self.render_sk_editor(cx)) + .child(self.render_region_editor(cx)) + ) + .child( + Label::new( + format!("You can also assign the {ZED_BEDROCK_AAID} and {ZED_BEDROCK_SK} environment variable and restart Zed."), + ) + .size(LabelSize::Small), + ) + .into_any() + } else { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("Access Key ID is set in {ZED_BEDROCK_AAID}, Secret Key is set in {ZED_BEDROCK_SK} environment variables.") + } else { + "Credentials configured.".to_string() + })), + ) + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set) + .when(env_var_set, |this| { + this.tooltip(|cx| Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_AAID} and {ZED_BEDROCK_SK} environment variables."), cx)) + }) + .on_click(cx.listener(|this, _, cx| this.reset_credentials(cx))), + ) + .into_any() + } + } +} + + diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 72dfd998d4bb2..95d78ef6b0431 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -13,6 +13,7 @@ use collections::BTreeMap; use gpui::{AppContext, EventEmitter, Global, Model, ModelContext}; use std::sync::Arc; use ui::Context; +use crate::provider::bedrock::BedrockLanguageModelProvider; pub fn init(user_store: Model, client: Arc, cx: &mut AppContext) { let registry = cx.new_model(|cx| { @@ -33,6 +34,11 @@ fn register_language_model_providers( RefreshLlmTokenListener::register(client.clone(), cx); + registry.register_provider( + BedrockLanguageModelProvider::new(cx), + cx + ); + registry.register_provider( AnthropicLanguageModelProvider::new(client.http_client(), cx), cx, diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 06dde1862ab37..b51b789a1b0ff 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -410,6 +410,15 @@ impl LanguageModelRequest { top_p: None, } } + + pub fn into_bedrock( + self, + model: String, + default_temperature: f32, + max_output_tokens: u32 + ) -> bedrock::Request { + todo!(); + } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 2bf8deb04238c..c521e911e58bc 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -19,6 +19,7 @@ use crate::{ }, LanguageModelCacheConfiguration, }; +use crate::provider::bedrock::{AmazonBedrockCredentials, AmazonBedrockSettings}; /// Initializes the language model settings. pub fn init(fs: Arc, cx: &mut AppContext) { @@ -55,6 +56,7 @@ pub fn init(fs: Arc, cx: &mut AppContext) { #[derive(Default)] pub struct AllLanguageModelSettings { + pub bedrock: AmazonBedrockSettings, pub anthropic: AnthropicSettings, pub ollama: OllamaSettings, pub openai: OpenAiSettings, @@ -65,6 +67,7 @@ pub struct AllLanguageModelSettings { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct AllLanguageModelSettingsContent { + pub bedrock: Option, pub anthropic: Option, pub ollama: Option, pub openai: Option, @@ -74,6 +77,14 @@ pub struct AllLanguageModelSettingsContent { pub copilot_chat: Option, } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct BedrockSettingsContent { + pub region: Option, + pub access_key_id: Option, + pub secret_access_key: Option, + pub available_models: Option> +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[serde(untagged)] pub enum AnthropicSettingsContent { From 62f8ce3ac0df292d56b8de4198df7db9355da795 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Mon, 28 Oct 2024 14:15:27 -0400 Subject: [PATCH 02/13] Compiles with the appropriate UI fields + saves credentials --- crates/bedrock/src/bedrock.rs | 14 -- crates/language_model/src/provider/bedrock.rs | 131 +++++++++++------- 2 files changed, 83 insertions(+), 62 deletions(-) diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index baba056760aae..8933d87fa4e32 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -37,8 +37,6 @@ pub enum Model { max_tokens: usize, /// The name displayed in the UI, such as in the assistant panel model dropdown menu. display_name: Option, - /// Override this model with a different Bedrock model for tool calls. - tool_override: Option, /// Indicates whether this custom model supports caching. cache_configuration: Option, max_output_tokens: Option, @@ -130,18 +128,6 @@ impl Model { } => default_temperature.unwrap_or(1.0), } } - - pub fn tool_model_id(&self) -> &str { - if let Self::Custom { - tool_override: Some(tool_override), - .. - } = self - { - tool_override - } else { - self.id() - } - } } pub async fn complete( diff --git a/crates/language_model/src/provider/bedrock.rs b/crates/language_model/src/provider/bedrock.rs index e0895a99068ef..8f60fbc02f14f 100644 --- a/crates/language_model/src/provider/bedrock.rs +++ b/crates/language_model/src/provider/bedrock.rs @@ -11,13 +11,11 @@ use aws_config::Region; use aws_credential_types::Credentials; use aws_sdk_bedrockruntime as bedrock_client; use aws_sdk_bedrockruntime::Config; -// use bedrock::{BedrockError, ContentDelta, Event, ResponseContent}; use collections::{BTreeMap, HashMap}; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _}; use gpui::{AnyView, AppContext, AsyncAppContext, FontStyle, Model, ModelContext, Subscription, Task, TextStyle, View, WhiteSpace}; -use http_client::HttpClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -26,11 +24,10 @@ use std::str::FromStr; use std::{sync::Arc, time::Duration}; use serde_json::Value; use strum::IntoEnumIterator; -use bedrock::BedrockError; +use bedrock::{BedrockError, ContentDelta, Event, ResponseContent}; use theme::ThemeSettings; use ui::{prelude::*, Icon, IconName, Tooltip}; use util::{maybe, ResultExt}; -use crate::provider::anthropic::{count_anthropic_tokens, map_to_language_model_completion_events}; const PROVIDER_ID : &str = "amazon-bedrock"; const PROVIDER_NAME : &str = "Amazon Bedrock"; @@ -216,6 +213,7 @@ struct BedrockModel { id: LanguageModelId, model: bedrock::Model, state: Model, + request_limiter: RateLimiter, } impl BedrockModel { @@ -246,7 +244,7 @@ impl LanguageModel for BedrockModel { } fn telemetry_id(&self) -> String { - format!("anthropic/{}", self.model.id()) + format!("bedrock/{}", self.model.id()) } fn max_token_count(&self) -> usize { @@ -262,7 +260,7 @@ impl LanguageModel for BedrockModel { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - count_anthropic_tokens(request, cx) + get_bedrock_tokens(request, cx) } fn stream_completion( @@ -284,6 +282,10 @@ impl LanguageModel for BedrockModel { async move { Ok(future.await?.boxed()) }.boxed() } + fn use_any_tool(&self, request: LanguageModelRequest, name: String, description: String, schema: Value, cx: &AsyncAppContext) -> BoxFuture<'static, Result>>> { + unimplemented!(); + } + fn cache_configuration(&self) -> Option { self.model .cache_configuration() @@ -293,42 +295,73 @@ impl LanguageModel for BedrockModel { min_total_token: config.min_total_token, }) } +} + +fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> BoxFuture<'static, Result> { - fn use_any_tool( - &self, - request: LanguageModelRequest, - tool_name: String, - tool_description: String, - input_schema: serde_json::Value, - cx: &AsyncAppContext, - ) -> BoxFuture<'static, Result>>> { - let mut request = request.into_anthropic( - self.model.tool_model_id().into(), - self.model.default_temperature(), - self.model.max_output_tokens(), - ); - request.tool_choice = Some(anthropic::ToolChoice::Tool { - name: tool_name.clone(), - }); - request.tools = vec![anthropic::Tool { - name: tool_name.clone(), - description: tool_description, - input_schema, - }]; - - let response = self.stream_completion(request, cx); - self.request_limiter - .run(async move { - let response = response.await?; - Ok(anthropic::extract_tool_args_from_events( - tool_name, - Box::pin(response.map_err(|e| anyhow!(e))), - ) - .await? - .boxed()) - }) - .boxed() - } +} + +pub fn map_to_language_model_completion_events( + events: Pin>>>, +) -> impl Stream> { + struct State { + events: Pin>>> + } + + futures::stream::unfold( + State { + events + }, + |mut state: State| async move { + while let Some(event) = state.events.next().await { + match event { + Ok(event) => match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + return Some(( + Some(Ok(LanguageModelCompletionEvent::Text(text))), + state, + )); + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + return Some(( + Some(Ok(LanguageModelCompletionEvent::Text(text))), + state, + )); + } + _ => {} + }, + Event::MessageDelta { delta, .. } => { + if let Some(stop_reason) = delta.stop_reason.as_deref() { + let stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + _ => StopReason::EndTurn, + }; + + return Some(( + Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))), + state, + )); + } + } + _ => {} + }, + Err(err) => { + return Some((Some(Err(anyhow!(err))), state)); + } + } + } + + None + } + ).filter_map(|event| async move { event }) } impl LanguageModelProvider for BedrockLanguageModelProvider { @@ -381,6 +414,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { id: LanguageModelId::from(model.id().to_string()), model, state: self.state.clone(), + request_limiter: RateLimiter::new(4) }) as Arc }) .collect() @@ -422,6 +456,7 @@ struct ConfigurationView { impl ConfigurationView { const PLACEHOLDER_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXX"; + const PLACEHOLDER_REGION: &'static str = "us-east-1"; fn new(state: gpui::Model, cx: &mut ViewContext) -> Self { cx.observe(&state, |_, _, cx| { @@ -460,7 +495,7 @@ impl ConfigurationView { }), region_editor: cx.new_view(|cx| { let mut editor = Editor::single_line(cx); - editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx); editor }), state, @@ -575,9 +610,9 @@ impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { const IAM_CONSOLE_URL: &str = "https://us-east-1.console.aws.amazon.com/iam/home"; const INSTRUCTIONS: [&str; 3] = [ - "To use Zed's assistant with Bedrock, you need to add the Access Key ID. and Secret Access Key. Follow these steps:", + "To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:", "- Create a pair at:", - "- Paste your Access Key ID and Secret Key below and hit enter to use the assistant:", + "- Paste your Access Key ID, Secret Key, and Region below and hit enter to use the assistant:", ]; let env_var_set = self.state.read(cx).credentials_from_env; @@ -589,7 +624,7 @@ impl Render for ConfigurationView { .on_action(cx.listener(Self::save_credentials)) .child(Label::new(INSTRUCTIONS[0])) .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child( - Button::new("anthropic_console", IAM_CONSOLE_URL) + Button::new("iam_console", IAM_CONSOLE_URL) .style(ButtonStyle::Subtle) .icon(IconName::ExternalLink) .icon_size(IconSize::XSmall) @@ -607,7 +642,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {ZED_BEDROCK_AAID} and {ZED_BEDROCK_SK} environment variable and restart Zed."), + format!("You can also assign the {ZED_BEDROCK_AAID}, {ZED_BEDROCK_SK} and {ZED_BEDROCK_REGION} environment variable and restart Zed."), ) .size(LabelSize::Small), ) @@ -621,7 +656,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("Access Key ID is set in {ZED_BEDROCK_AAID}, Secret Key is set in {ZED_BEDROCK_SK} environment variables.") + format!("Access Key ID is set in {ZED_BEDROCK_AAID}, Secret Key is set in {ZED_BEDROCK_SK}, Region is set in {ZED_BEDROCK_REGION} environment variables.") } else { "Credentials configured.".to_string() })), @@ -633,7 +668,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(|cx| Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_AAID} and {ZED_BEDROCK_SK} environment variables."), cx)) + this.tooltip(|cx| Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_AAID}, {ZED_BEDROCK_SK}, and {ZED_BEDROCK_REGION} environment variables."), cx)) }) .on_click(cx.listener(|this, _, cx| this.reset_credentials(cx))), ) From b146a6d9e489ede1bf09b47c4e14c8284e4423fc Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Thu, 31 Oct 2024 17:57:03 -0400 Subject: [PATCH 03/13] Starting the transpose to the Bedrock SDK --- crates/assistant/src/assistant_settings.rs | 5 + crates/bedrock/src/bedrock.rs | 122 +++++++++--------- crates/collab/src/llm.rs | 4 + crates/language_model/src/provider/bedrock.rs | 51 +++++++- crates/language_model/src/request.rs | 83 +++++++++++- crates/rpc/src/llm.rs | 1 + 6 files changed, 199 insertions(+), 67 deletions(-) diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 5a2d234f6f7ae..fdc9871fe8c42 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -50,6 +50,11 @@ pub enum AssistantProviderContentV1 { api_url: Option, low_speed_timeout_in_seconds: Option, }, + #[serde(rename = "bedrock")] + Bedrock { + default_model: Option, + region: Option, + }, } #[derive(Debug, Default)] diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 8933d87fa4e32..1bf3d854191f6 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -2,14 +2,21 @@ use std::time::Duration; use std::{pin::Pin, str::FromStr}; use anyhow::{anyhow, Context, Result}; -use aws_sdk_bedrockruntime as bedrock; +use aws_sdk_bedrockruntime::types::ConverseStreamOutput; use chrono::{DateTime, Utc}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString}; use thiserror::Error; use util::ResultExt as _; +pub use aws_sdk_bedrockruntime as bedrock; + +pub use bedrock::operation::converse_stream::ConverseStreamInput as StreamingRequest; +pub use bedrock::Error as BedrockError; + +//TODO: Re-export the Bedrock stuff +// https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] @@ -141,14 +148,57 @@ pub async fn complete( pub async fn stream_completion( client: &bedrock::Client, - api_url: &str, - api_key: &str, request: Request, low_speed_timeout: Option, ) -> Result>, BedrockError> { - todo!() + + let response = bedrock::Client::converse_stream(client) + .model_id(request.model) + .set_messages(request.messages.into()).send().await; + + let mut stream = match response { + Ok(output) => Ok(output.stream), + Err(e) => { + // TODO: Figure this out + unimplemented!(); + } + }; + + if stream.is_ok() { + let reader = BufReader::new(stream); + let stream = reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let event: bedrock = get_converse_output_text(line); + Some(Ok(event)) + } + Err(e) => Some(Err(e.into())), + } + }).boxed(); + + Ok(stream) + } } +fn get_converse_output_text( + output: ConverseStreamOutput, +) -> Result { + Ok(match output { + ConverseStreamOutput::ContentBlockDelta(c) => { + match c.delta() { + Some(delta) => delta.as_text().cloned().unwrap_or_else(|_| "".into()), + None => "".into(), + } + } + _ => { + String::from("") + } + }) +} +//TODO: A LOT of these types need to re-export the Bedrock types instead of making custom ones + #[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[serde(rename_all = "lowercase")] pub enum CacheControlType { @@ -161,18 +211,9 @@ pub struct CacheControl { pub cache_type: CacheControlType, } -#[derive(Debug, Serialize, Deserialize)] -pub struct Message { - pub role: Role, - pub content: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, -} +pub use bedrock::types::Message; +pub use bedrock::types::ConversationRole; +pub use bedrock::types::ResponseStream; #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] @@ -225,13 +266,6 @@ pub struct Request { pub top_p: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct StreamingRequest { - #[serde(flatten)] - pub base: Request, - pub stream: bool, -} - #[derive(Debug, Serialize, Deserialize)] pub struct Metadata { pub user_id: Option, @@ -264,30 +298,6 @@ pub struct Response { pub usage: Usage, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum Event { - #[serde(rename = "message_start")] - MessageStart { message: Response }, - #[serde(rename = "content_block_start")] - ContentBlockStart { - index: usize, - content_block: ResponseContent, - }, - #[serde(rename = "content_block_delta")] - ContentBlockDelta { index: usize, delta: ContentDelta }, - #[serde(rename = "content_block_stop")] - ContentBlockStop { index: usize }, - #[serde(rename = "message_delta")] - MessageDelta { delta: MessageDelta, usage: Usage }, - #[serde(rename = "message_stop")] - MessageStop, - #[serde(rename = "ping")] - Ping, - #[serde(rename = "error")] - Error { error: ApiError }, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ContentDelta { @@ -302,19 +312,3 @@ pub struct MessageDelta { pub stop_reason: Option, pub stop_sequence: Option, } - -#[derive(Error, Debug)] -pub enum BedrockError { - // TODO: propagate the error message - #[error("an error occurred while interacting with the Bedrock API")] - ApiError(bedrock::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ApiError { - #[serde(rename = "type")] - pub error_type: String, - pub message: String, -} diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 654327c4637ad..c80c3cc8b244d 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -409,6 +409,10 @@ async fn perform_completion( }) .boxed() } + LanguageModelProvider::Bedrock => { + // TODO: implement this + unimplemented!() + } }; Ok(Response::new(Body::wrap_stream(TokenCountingStream { diff --git a/crates/language_model/src/provider/bedrock.rs b/crates/language_model/src/provider/bedrock.rs index 8f60fbc02f14f..b22819c0066a0 100644 --- a/crates/language_model/src/provider/bedrock.rs +++ b/crates/language_model/src/provider/bedrock.rs @@ -221,8 +221,8 @@ impl BedrockModel { &self, request: bedrock::Request, cx: &AsyncAppContext - ) -> BoxFuture<'static, Result>>> { - todo!() + ) -> BoxFuture<'static, Result>>> { + } } @@ -298,7 +298,54 @@ impl LanguageModel for BedrockModel { } fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> BoxFuture<'static, Result> { + cx.background_executor() + .spawn(async move { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + use crate::MessageContent; + + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + unimplemented!(); + } + MessageContent::ToolResult(tool_result) => { + unimplemented!(); + } + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| tokens + tokens_from_images) + }) + .boxed() } pub fn map_to_language_model_completion_events( diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index b51b789a1b0ff..b5f1c7d1c2bad 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -417,7 +417,88 @@ impl LanguageModelRequest { default_temperature: f32, max_output_tokens: u32 ) -> bedrock::Request { - todo!(); + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in self.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let cache_control = if message.cache { + Some(bedrock::CacheControl { + cache_type: bedrock::CacheControlType::Ephemeral, + }) + } else { + None + }; + let bedrock_message_content: Vec = message + .content + .into_iter() + .filter_map(|content| match content { + MessageContent::Text(text) => { + if !text.is_empty() { + Some(bedrock::RequestContent::Text { + text, + cache_control, + }) + } else { + None + } + } + MessageContent::Image(image) => { + Some(bedrock::RequestContent::Image { + source: bedrock::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control, + }) + } + _ => { + unimplemented!() + } + }) + .collect(); + let bedrock_role = match message.role { + Role::User => bedrock::Role::User, + Role::Assistant => bedrock::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == bedrock_role { + last_message.content.extend(bedrock_message_content); + continue; + } + } + new_messages.push(bedrock::Message { + role: bedrock_role, + content: bedrock_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + bedrock::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: Some(system_message), + metadata: None, + stop_sequences: Vec::new(), + temperature: self.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } } } diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 0a7510d891d35..216599f738fdb 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -14,6 +14,7 @@ pub enum LanguageModelProvider { Anthropic, OpenAi, Google, + Bedrock } #[derive(Debug, Serialize, Deserialize)] From 1240bad1a3623a15bf32e732c67bbf69c95c1ead Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Tue, 5 Nov 2024 15:33:23 -0500 Subject: [PATCH 04/13] Added Bedrock Error enum that transparently exposes SDK Errors as Bedrock errors, while also letting us include it --- crates/bedrock/src/bedrock.rs | 114 +++++++++++---------------- crates/language_model/src/request.rs | 30 +++---- 2 files changed, 56 insertions(+), 88 deletions(-) diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 1bf3d854191f6..34c760d3b19b6 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,31 +1,27 @@ use std::time::Duration; use std::{pin::Pin, str::FromStr}; - -use anyhow::{anyhow, Context, Result}; -use aws_sdk_bedrockruntime::types::ConverseStreamOutput; +use std::any::Any; +use anyhow::{anyhow, Context, Error, Result}; +use aws_sdk_bedrockruntime::types::{ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStopEvent, ConverseStreamOutput}; use chrono::{DateTime, Utc}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString}; use thiserror::Error; -use util::ResultExt as _; pub use aws_sdk_bedrockruntime as bedrock; - +use aws_sdk_bedrockruntime::config::http::HttpResponse; +use aws_sdk_bedrockruntime::operation::converse::{ConverseError, ConverseOutput}; pub use bedrock::operation::converse_stream::ConverseStreamInput as StreamingRequest; -pub use bedrock::Error as BedrockError; +pub use bedrock::types::ContentBlock as RequestContent; +pub use bedrock::types::ConverseOutput as Response; +pub use bedrock::types::Message; +pub use bedrock::types::ConversationRole; +pub use bedrock::types::ResponseStream; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] -pub struct BedrockModelCacheConfiguration { - pub min_total_token: usize, - pub should_speculate: bool, - pub max_cache_anchors: usize, -} - #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { @@ -44,8 +40,6 @@ pub enum Model { max_tokens: usize, /// The name displayed in the UI, such as in the assistant panel model dropdown menu. display_name: Option, - /// Indicates whether this custom model supports caching. - cache_configuration: Option, max_output_tokens: Option, default_temperature: Option, }, @@ -88,21 +82,6 @@ impl Model { } } - pub fn cache_configuration(&self) -> Option { - match self { - Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(BedrockModelCacheConfiguration { - min_total_token: 2_048, - should_speculate: true, - max_cache_anchors: 4, - }), - Self::Custom { - cache_configuration, - .. - } => cache_configuration.clone(), - _ => None, - } - } - pub fn max_token_count(&self) -> usize { match self { Self::Claude3_5Sonnet @@ -139,18 +118,28 @@ impl Model { pub async fn complete( client: &bedrock::Client, - api_url: &str, - api_key: &str, request: Request, ) -> Result { - todo!() + let mut response = bedrock::Client::converse(client) + .model_id(request.model.clone()) + .set_messages(request.messages.into()) + .send().await.context("Failed to send request to Bedrock"); + + match response { + Ok(output) => { + Ok(output.into()) + } + Err(err) => { + Err(anyhow!(err)) + } + } } pub async fn stream_completion( client: &bedrock::Client, request: Request, low_speed_timeout: Option, -) -> Result>, BedrockError> { +) -> Result>, BedrockError> { // There is no generic Bedrock event Type? let response = bedrock::Client::converse_stream(client) .model_id(request.model) @@ -160,6 +149,7 @@ pub async fn stream_completion( Ok(output) => Ok(output.stream), Err(e) => { // TODO: Figure this out + unimplemented!(); } }; @@ -198,7 +188,6 @@ fn get_converse_output_text( }) } //TODO: A LOT of these types need to re-export the Bedrock types instead of making custom ones - #[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[serde(rename_all = "lowercase")] pub enum CacheControlType { @@ -211,27 +200,6 @@ pub struct CacheControl { pub cache_type: CacheControlType, } -pub use bedrock::types::Message; -pub use bedrock::types::ConversationRole; -pub use bedrock::types::ResponseStream; - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum RequestContent { - #[serde(rename = "text")] - Text { - text: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, - }, - #[serde(rename = "image")] - Image { - source: ImageSource, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, - } -} - #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ResponseContent { @@ -283,21 +251,27 @@ pub struct Usage { pub cache_read_input_tokens: Option, } -#[derive(Debug, Serialize, Deserialize)] -pub struct Response { - pub id: String, - #[serde(rename = "type")] - pub response_type: String, - pub role: Role, - pub content: Vec, - pub model: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub stop_reason: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub stop_sequence: Option, - pub usage: Usage, +#[derive(Error, Debug)] +pub enum BedrockError { + SdkError(bedrock::Error), + Other(anyhow::Error) } +// #[derive(Debug, Serialize, Deserialize)] +// pub struct Response { +// pub id: String, +// #[serde(rename = "type")] +// pub response_type: String, +// pub role: ConversationRole, +// pub content: Vec, +// pub model: String, +// #[serde(default, skip_serializing_if = "Option::is_none")] +// pub stop_reason: Option, +// #[serde(default, skip_serializing_if = "Option::is_none")] +// pub stop_sequence: Option, +// pub usage: Usage, +// } + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ContentDelta { diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index b5f1c7d1c2bad..1cadbebc3327f 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,9 +1,9 @@ use std::io::{Cursor, Write}; - +use aws_sdk_bedrockruntime::types::ImageBlock; use crate::role::Role; use crate::LanguageModelToolUse; use base64::write::EncoderWriter; -use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task}; +use gpui::{point, size, AppContext, DevicePixels, Image, ImageFormat, ObjectFit, RenderImage, Size, Task}; use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder}; use serde::{Deserialize, Serialize}; use ui::{px, SharedString}; @@ -440,23 +440,17 @@ impl LanguageModelRequest { .filter_map(|content| match content { MessageContent::Text(text) => { if !text.is_empty() { - Some(bedrock::RequestContent::Text { - text, - cache_control, - }) + Some(bedrock::RequestContent::Text(text)) } else { None } } MessageContent::Image(image) => { - Some(bedrock::RequestContent::Image { - source: bedrock::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - cache_control, - }) + todo!() + // Some(bedrock::RequestContent::Image(ImageBlock{ + // format: ImageFormat::, + // source: None, + // }) } _ => { unimplemented!() @@ -464,9 +458,9 @@ impl LanguageModelRequest { }) .collect(); let bedrock_role = match message.role { - Role::User => bedrock::Role::User, - Role::Assistant => bedrock::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), + Role::User => bedrock::ConversationRole::User, + Role::Assistant => bedrock::ConversationRole::Assistant, + Role::System => unreachable!("System role should never occur here") }; if let Some(last_message) = new_messages.last_mut() { if last_message.role == bedrock_role { @@ -476,7 +470,7 @@ impl LanguageModelRequest { } new_messages.push(bedrock::Message { role: bedrock_role, - content: bedrock_message_content, + content: bedrock_message_content }); } Role::System => { From c7235219b5c0858b1f0665044648cc4482a4c39c Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Tue, 5 Nov 2024 22:58:53 -0500 Subject: [PATCH 05/13] Had Q Dev rewrite most of the model enums based on the AWS Bedrock list of FMs --- crates/bedrock/src/bedrock.rs | 111 +------------------ crates/bedrock/src/models.rs | 195 ++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 108 deletions(-) create mode 100644 crates/bedrock/src/models.rs diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 34c760d3b19b6..60ffc99712634 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,3 +1,5 @@ +mod models; + use std::time::Duration; use std::{pin::Pin, str::FromStr}; use std::any::Any; @@ -22,99 +24,7 @@ pub use bedrock::types::ResponseStream; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] -pub enum Model { - #[default] - #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")] - Claude3_5Sonnet, - #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] - Claude3Opus, - #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] - Claude3Sonnet, - #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")] - Claude3Haiku, - #[serde(rename = "custom")] - Custom { - name: String, - max_tokens: usize, - /// The name displayed in the UI, such as in the assistant panel model dropdown menu. - display_name: Option, - max_output_tokens: Option, - default_temperature: Option, - }, -} - -impl Model { - pub fn from_id(id: &str) -> Result { - if id.starts_with("claude-3-5-sonnet") { - Ok(Self::Claude3_5Sonnet) - } else if id.starts_with("claude-3-opus") { - Ok(Self::Claude3Opus) - } else if id.starts_with("claude-3-sonnet") { - Ok(Self::Claude3Sonnet) - } else if id.starts_with("claude-3-haiku") { - Ok(Self::Claude3Haiku) - } else { - Err(anyhow!("invalid model id")) - } - } - - pub fn id(&self) -> &str { - match self { - Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest", - Model::Claude3Opus => "claude-3-opus-latest", - Model::Claude3Sonnet => "claude-3-sonnet-latest", - Model::Claude3Haiku => "claude-3-haiku-latest", - Self::Custom { name, .. } => name, - } - } - - pub fn display_name(&self) -> &str { - match self { - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Custom { - name, display_name, .. - } => display_name.as_ref().unwrap_or(name), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200_000, - Self::Custom { max_tokens, .. } => *max_tokens, - } - } - - pub fn max_output_tokens(&self) -> u32 { - match self { - Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096, - Self::Claude3_5Sonnet => 8_192, - Self::Custom { - max_output_tokens, .. - } => max_output_tokens.unwrap_or(4_096), - } - } - - pub fn default_temperature(&self) -> f32 { - match self { - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 1.0, - Self::Custom { - default_temperature, - .. - } => default_temperature.unwrap_or(1.0), - } - } -} +pub use models::*; pub async fn complete( client: &bedrock::Client, @@ -257,21 +167,6 @@ pub enum BedrockError { Other(anyhow::Error) } -// #[derive(Debug, Serialize, Deserialize)] -// pub struct Response { -// pub id: String, -// #[serde(rename = "type")] -// pub response_type: String, -// pub role: ConversationRole, -// pub content: Vec, -// pub model: String, -// #[serde(default, skip_serializing_if = "Option::is_none")] -// pub stop_reason: Option, -// #[serde(default, skip_serializing_if = "Option::is_none")] -// pub stop_sequence: Option, -// pub usage: Usage, -// } - #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ContentDelta { diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs new file mode 100644 index 0000000000000..a2acb8b9cf0e0 --- /dev/null +++ b/crates/bedrock/src/models.rs @@ -0,0 +1,195 @@ +use anyhow::anyhow; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[default] + #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")] + Claude3_5Sonnet, + #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] + Claude3Opus, + #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] + Claude3Sonnet, + #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] + Claude3_5Haiku, + // AI21 models + AI21J2GrandeInstruct, + AI21J2JumboInstruct, + AI21J2Mid, + AI21J2MidV1, + AI21J2Ultra, + AI21J2UltraV1_8k, + AI21J2UltraV1, + AI21JambaInstructV1, + AI21Jamba15LargeV1, + AI21Jamba15MiniV1, + // Anthropic models (already included) + // Cohere models + CohereCommandTextV14_4k, + CohereCommandRV1, + CohereCommandRPlusV1, + CohereCommandLightTextV14_4k, + // Meta models + MetaLlama38BInstructV1, + MetaLlama370BInstructV1, + MetaLlama318BInstructV1_128k, + MetaLlama318BInstructV1, + MetaLlama3170BInstructV1_128k, + MetaLlama3170BInstructV1, + MetaLlama3211BInstructV1, + MetaLlama3290BInstructV1, + MetaLlama321BInstructV1, + MetaLlama323BInstructV1, + // Mistral models + MistralMistral7BInstructV0, + MistralMixtral8x7BInstructV0, + MistralMistralLarge2402V1, + MistralMistralSmall2402V1, + #[serde(rename = "custom")] + Custom { + name: String, + max_tokens: usize, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + max_output_tokens: Option, + default_temperature: Option, + }, +} + + +impl Model { + pub fn from_id(id: &str) -> anyhow::Result { + if id.starts_with("claude-3-5-sonnet") { + Ok(Self::Claude3_5Sonnet) + } else if id.starts_with("claude-3-opus") { + Ok(Self::Claude3Opus) + } else if id.starts_with("claude-3-sonnet") { + Ok(Self::Claude3Sonnet) + } else if id.starts_with("claude-3-5-haiku") { + Ok(Self::Claude3_5Haiku) + } else { + Err(anyhow!("invalid model id")) + } + } + + pub fn id(&self) -> &str { + match self { + Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20241022-v2:0", + Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", + Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0", + Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0", + Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct", + Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct", + Model::AI21J2Mid => "ai21.j2-mid", + Model::AI21J2MidV1 => "ai21.j2-mid-v1", + Model::AI21J2Ultra => "ai21.j2-ultra", + Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k", + Model::AI21J2UltraV1 => "ai21.j2-ultra-v1", + Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0", + Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0", + Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0", + Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k", + Model::CohereCommandRV1 => "cohere.command-r-v1:0", + Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0", + Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k", + Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0", + Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0", + Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k", + Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0", + Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k", + Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0", + Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0", + Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0", + Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0", + Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0", + Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2", + Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1", + Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0", + Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3_5Haiku => "Claude 3.5 Haiku", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3_5Haiku => 200_000, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> u32 { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096, + Self::Claude3_5Sonnet => 8_192, + Self::Custom { + max_output_tokens, .. + } => max_output_tokens.unwrap_or(4_096), + } + } + + pub fn default_temperature(&self) -> f32 { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3_5Haiku => 1.0, + Self::Custom { + default_temperature, + .. + } => default_temperature.unwrap_or(1.0), + } + } +} + +/** +"ai21.j2-grande-instruct" +"ai21.j2-jumbo-instruct" +"ai21.j2-mid" +"ai21.j2-mid-v1" +"ai21.j2-ultra" +"ai21.j2-ultra-v1:0:8k" +"ai21.j2-ultra-v1" +"ai21.jamba-instruct-v1:0" +"ai21.jamba-1-5-large-v1:0" +"ai21.jamba-1-5-mini-v1:0" +"anthropic.claude-3-sonnet-20240229-v1:0" +"anthropic.claude-3-haiku-20240307-v1:0" +"anthropic.claude-3-opus-20240229-v1:0" +"anthropic.claude-3-5-sonnet-20241022-v2:0" +"anthropic.claude-3-5-haiku-20241022-v1:0" +"cohere.command-text-v14:7:4k" +"cohere.command-r-v1:0" +"cohere.command-r-plus-v1:0" +"cohere.command-light-text-v14:7:4k" +"meta.llama3-8b-instruct-v1:0" +"meta.llama3-70b-instruct-v1:0" +"meta.llama3-1-8b-instruct-v1:0:128k" +"meta.llama3-1-8b-instruct-v1:0" +"meta.llama3-1-70b-instruct-v1:0:128k" +"meta.llama3-1-70b-instruct-v1:0" +"meta.llama3-2-11b-instruct-v1:0" +"meta.llama3-2-90b-instruct-v1:0" +"meta.llama3-2-1b-instruct-v1:0" +"meta.llama3-2-3b-instruct-v1:0" +"mistral.mistral-7b-instruct-v0:2" +"mistral.mixtral-8x7b-instruct-v0:1" +"mistral.mistral-large-2402-v1:0" +"mistral.mistral-small-2402-v1:0" +**/ \ No newline at end of file From 4fcc8ecad8ac2bfea14a97c6144cf5ba9c554748 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Tue, 5 Nov 2024 23:04:34 -0500 Subject: [PATCH 06/13] Fixed the names in impl Model --- crates/bedrock/src/models.rs | 79 ++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index a2acb8b9cf0e0..fd004f2c8efab 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -118,9 +118,35 @@ impl Model { Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Sonnet => "Claude 3 Sonnet", Self::Claude3_5Haiku => "Claude 3.5 Haiku", - Self::Custom { - name, display_name, .. - } => display_name.as_ref().unwrap_or(name), + Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct", + Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct", + Self::AI21J2Mid => "AI21 Jurassic2 Mid", + Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1", + Self::AI21J2Ultra => "AI21 Jurassic2 Ultra", + Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K", + Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1", + Self::AI21JambaInstructV1 => "AI21 Jamba Instruct", + Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large", + Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini", + Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K", + Self::CohereCommandRV1 => "Cohere Command R V1", + Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1", + Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K", + Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1", + Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1", + Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K", + Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1", + Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K", + Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1", + Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1", + Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1", + Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1", + Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1", + Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0", + Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0", + Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1", + Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1", + Self::Custom { display_name, name, .. } => display_name.as_deref().unwrap_or(name), } } @@ -131,6 +157,9 @@ impl Model { | Self::Claude3Sonnet | Self::Claude3_5Haiku => 200_000, Self::Custom { max_tokens, .. } => *max_tokens, + _ => { + 200_000 + } } } @@ -141,6 +170,9 @@ impl Model { Self::Custom { max_output_tokens, .. } => max_output_tokens.unwrap_or(4_096), + _ => { + 4_096 + } } } @@ -154,42 +186,9 @@ impl Model { default_temperature, .. } => default_temperature.unwrap_or(1.0), + _ => { + 1.0 + } } } -} - -/** -"ai21.j2-grande-instruct" -"ai21.j2-jumbo-instruct" -"ai21.j2-mid" -"ai21.j2-mid-v1" -"ai21.j2-ultra" -"ai21.j2-ultra-v1:0:8k" -"ai21.j2-ultra-v1" -"ai21.jamba-instruct-v1:0" -"ai21.jamba-1-5-large-v1:0" -"ai21.jamba-1-5-mini-v1:0" -"anthropic.claude-3-sonnet-20240229-v1:0" -"anthropic.claude-3-haiku-20240307-v1:0" -"anthropic.claude-3-opus-20240229-v1:0" -"anthropic.claude-3-5-sonnet-20241022-v2:0" -"anthropic.claude-3-5-haiku-20241022-v1:0" -"cohere.command-text-v14:7:4k" -"cohere.command-r-v1:0" -"cohere.command-r-plus-v1:0" -"cohere.command-light-text-v14:7:4k" -"meta.llama3-8b-instruct-v1:0" -"meta.llama3-70b-instruct-v1:0" -"meta.llama3-1-8b-instruct-v1:0:128k" -"meta.llama3-1-8b-instruct-v1:0" -"meta.llama3-1-70b-instruct-v1:0:128k" -"meta.llama3-1-70b-instruct-v1:0" -"meta.llama3-2-11b-instruct-v1:0" -"meta.llama3-2-90b-instruct-v1:0" -"meta.llama3-2-1b-instruct-v1:0" -"meta.llama3-2-3b-instruct-v1:0" -"mistral.mistral-7b-instruct-v0:2" -"mistral.mixtral-8x7b-instruct-v0:1" -"mistral.mistral-large-2402-v1:0" -"mistral.mistral-small-2402-v1:0" -**/ \ No newline at end of file +} \ No newline at end of file From 2932c19175b989c15f4a9f9bf24327af6eeebae7 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Fri, 15 Nov 2024 09:14:09 -0500 Subject: [PATCH 07/13] removed a whole host of unsupported types --- Cargo.lock | 88 ++++++++++++++++++-------- crates/bedrock/src/bedrock.rs | 116 +++++++++------------------------- 2 files changed, 91 insertions(+), 113 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 073f7b43b9651..c7c4fac882670 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1099,9 +1099,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.4.2" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2424565416eef55906f9f8cece2072b6b6a76075e3ff81483ebe938a89a4c05f" +checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -1123,6 +1123,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrockruntime" +version = "1.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b257117450ded4c42fda3da48b686bacc9b1eac59e44108e80ade005c4029c2" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes 1.7.2", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-s3" version = "1.47.0" @@ -1227,9 +1250,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.3" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5df1b0fa6be58efe9d4ccc257df0a53b89cd8909e86591a13ca54817c87517be" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -1288,9 +1311,9 @@ dependencies = [ [[package]] name = "aws-smithy-eventstream" -version = "0.60.4" +version = "0.60.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6363078f927f612b970edf9d1903ef5cef9a64d1e8423525ebb1f0a1633c858" +checksum = "cef7d0a272725f87e51ba2bf89f8c21e4df61b9e49ae1ac367a6d69916ef7c90" dependencies = [ "aws-smithy-types", "bytes 1.7.2", @@ -1299,9 +1322,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.10" +version = "0.60.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01dbcb6e2588fd64cfb6d7529661b06466419e4c54ed1c62d6510d2d0350a728" +checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -1339,9 +1362,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -1383,9 +1406,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.4" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273dcdfd762fae3e1650b8024624e7cd50e484e37abdab73a7a706188ad34543" +checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" dependencies = [ "base64-simd", "bytes 1.7.2", @@ -1559,6 +1582,25 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bedrock" +version = "0.1.0" +dependencies = [ + "anyhow", + "aws-config", + "aws-sdk-bedrockruntime", + "chrono", + "futures 0.3.30", + "http_client", + "schemars", + "serde", + "serde_json", + "strum 0.25.0", + "thiserror", + "tokio", + "util", +] + [[package]] name = "bigdecimal" version = "0.4.5" @@ -1591,7 +1633,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "proc-macro2", @@ -5644,7 +5686,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -6341,7 +6383,11 @@ version = "0.1.0" dependencies = [ "anthropic", "anyhow", + "aws-config", + "aws-credential-types", + "aws-sdk-bedrockruntime", "base64 0.22.1", + "bedrock", "client", "collections", "copilot", @@ -6553,7 +6599,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -14228,7 +14274,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -14988,16 +15034,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9" -[[package]] -name = "xtask" -version = "0.1.0" -dependencies = [ - "anyhow", - "cargo_metadata", - "cargo_toml", - "clap", -] - [[package]] name = "yaml-rust2" version = "0.8.1" diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 60ffc99712634..e303324416db5 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,25 +1,21 @@ mod models; -use std::time::Duration; -use std::{pin::Pin, str::FromStr}; +use std::{str::FromStr}; use std::any::Any; -use anyhow::{anyhow, Context, Error, Result}; -use aws_sdk_bedrockruntime::types::{ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStopEvent, ConverseStreamOutput}; -use chrono::{DateTime, Utc}; +use anyhow::{Context, Result}; +use aws_sdk_bedrockruntime::types::{ConverseStreamOutput}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; -use strum::{EnumIter, EnumString}; use thiserror::Error; pub use aws_sdk_bedrockruntime as bedrock; -use aws_sdk_bedrockruntime::config::http::HttpResponse; -use aws_sdk_bedrockruntime::operation::converse::{ConverseError, ConverseOutput}; pub use bedrock::operation::converse_stream::ConverseStreamInput as StreamingRequest; pub use bedrock::types::ContentBlock as RequestContent; pub use bedrock::types::ConverseOutput as Response; pub use bedrock::types::Message; pub use bedrock::types::ConversationRole; pub use bedrock::types::ResponseStream; +pub use bedrock::types::ConverseStreamOutput as BedrockEvent; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html @@ -37,10 +33,10 @@ pub async fn complete( match response { Ok(output) => { - Ok(output.into()) + Ok(output.output.unwrap()) } Err(err) => { - Err(anyhow!(err)) + Err(BedrockError::Other(err)) } } } @@ -48,43 +44,42 @@ pub async fn complete( pub async fn stream_completion( client: &bedrock::Client, request: Request, - low_speed_timeout: Option, ) -> Result>, BedrockError> { // There is no generic Bedrock event Type? let response = bedrock::Client::converse_stream(client) .model_id(request.model) .set_messages(request.messages.into()).send().await; + let mut stream = match response { Ok(output) => Ok(output.stream), - Err(e) => { - // TODO: Figure this out - - unimplemented!(); - } - }; - - if stream.is_ok() { - let reader = BufReader::new(stream); - let stream = reader - .lines() - .filter_map(|line| async move { - match line { - Ok(line) => { - let event: bedrock = get_converse_output_text(line); - Some(Ok(event)) - } - Err(e) => Some(Err(e.into())), - } - }).boxed(); - - Ok(stream) + Err(e) => Err( + BedrockError::SdkError(e.as_service_error().unwrap()) + ), + }?; + + loop { + let token = stream.recv().await; + match token { + Ok(Some(text)) => { + let next = get_converse_output_text(text)?; + print!("{}", next); + Ok(()) + } + Ok(None) => break, + Err(e) => Err(e + .as_service_error() + .map(BedrockConverseStreamError::from) + .unwrap_or(BedrockConverseStreamError( + "Unknown error receiving stream".into(), + ))), + }? } } fn get_converse_output_text( output: ConverseStreamOutput, -) -> Result { +) -> Result { Ok(match output { ConverseStreamOutput::ContentBlockDelta(c) => { match c.delta() { @@ -98,32 +93,6 @@ fn get_converse_output_text( }) } //TODO: A LOT of these types need to re-export the Bedrock types instead of making custom ones -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] -#[serde(rename_all = "lowercase")] -pub enum CacheControlType { - Ephemeral, -} - -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] -pub struct CacheControl { - #[serde(rename = "type")] - pub cache_type: CacheControlType, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum ResponseContent { - #[serde(rename = "text")] - Text { text: String }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ImageSource { - #[serde(rename = "type")] - pub source_type: String, - pub media_type: String, - pub data: String, -} #[derive(Debug, Serialize, Deserialize)] pub struct Request { @@ -149,35 +118,8 @@ pub struct Metadata { pub user_id: Option, } -#[derive(Debug, Serialize, Deserialize)] -pub struct Usage { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub input_tokens: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub output_tokens: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cache_creation_input_tokens: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cache_read_input_tokens: Option, -} - #[derive(Error, Debug)] pub enum BedrockError { SdkError(bedrock::Error), Other(anyhow::Error) } - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum ContentDelta { - #[serde(rename = "text_delta")] - TextDelta { text: String }, - #[serde(rename = "input_json_delta")] - InputJsonDelta { partial_json: String }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct MessageDelta { - pub stop_reason: Option, - pub stop_sequence: Option, -} From 3ff200a37251167fb6ba385c78599b71b8694c95 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Fri, 22 Nov 2024 23:34:23 -0500 Subject: [PATCH 08/13] seemingly migrated to the new way of managing language models --- Cargo.lock | 47 ++++++ crates/bedrock/src/bedrock.rs | 26 ++- crates/language_model/Cargo.toml | 2 - .../language_model/src/model/cloud_model.rs | 13 +- crates/language_model/src/model/mod.rs | 1 + crates/language_models/Cargo.toml | 3 + .../language_models/src/provider/bedrock.rs | 158 ++++-------------- crates/language_models/src/provider/cloud.rs | 22 ++- 8 files changed, 137 insertions(+), 135 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9881c23e84e93..3bb78680956a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1146,6 +1146,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrockruntime" +version = "1.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30652d08e84639fa570b696070e974d3153724701f81f89112e90986c811c39d" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes 1.8.0", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-kinesis" version = "1.51.0" @@ -1603,6 +1626,25 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bedrock" +version = "0.1.0" +dependencies = [ + "anyhow", + "aws-config", + "aws-sdk-bedrockruntime", + "chrono", + "futures 0.3.31", + "http_client", + "schemars", + "serde", + "serde_json", + "strum 0.25.0", + "thiserror 1.0.69", + "tokio", + "util", +] + [[package]] name = "bigdecimal" version = "0.4.6" @@ -6554,7 +6596,9 @@ version = "0.1.0" dependencies = [ "anthropic", "anyhow", + "aws-sdk-bedrockruntime", "base64 0.22.1", + "bedrock", "collections", "futures 0.3.31", "google_ai", @@ -6581,6 +6625,9 @@ version = "0.1.0" dependencies = [ "anthropic", "anyhow", + "aws-config", + "aws-credential-types", + "bedrock", "client", "collections", "copilot", diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index e303324416db5..a13a23152e64c 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -3,19 +3,19 @@ mod models; use std::{str::FromStr}; use std::any::Any; use anyhow::{Context, Result}; -use aws_sdk_bedrockruntime::types::{ConverseStreamOutput}; +use aws_sdk_bedrockruntime::types::{ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ConverseStreamOutput, Message, MessageStartEvent, MessageStopEvent}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; use thiserror::Error; -pub use aws_sdk_bedrockruntime as bedrock; -pub use bedrock::operation::converse_stream::ConverseStreamInput as StreamingRequest; -pub use bedrock::types::ContentBlock as RequestContent; -pub use bedrock::types::ConverseOutput as Response; -pub use bedrock::types::Message; -pub use bedrock::types::ConversationRole; -pub use bedrock::types::ResponseStream; -pub use bedrock::types::ConverseStreamOutput as BedrockEvent; +use aws_sdk_bedrockruntime as bedrock; +pub use aws_sdk_bedrockruntime as bedrock_client; +pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; +pub use bedrock::types::ContentBlock as BedrockRequestContent; +use bedrock::types::ConverseOutput as Response; +pub use bedrock::types::Message as BedrockMessage; +pub use bedrock::types::ConversationRole as BedrockRole; +pub use bedrock::types::ResponseStream as BedrockResponseStream; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html @@ -123,3 +123,11 @@ pub enum BedrockError { SdkError(bedrock::Error), Other(anyhow::Error) } + +pub enum BedrockEvent { + ContentBlockDelta(ContentBlockDeltaEvent), + ContentBlockStart(ContentBlockStartEvent), + MessageStart(MessageStartEvent), + MessageStop(MessageStopEvent), + Metadata(ConverseStreamMetadataEvent), +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index f2054df183565..3caa88642dcb7 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -18,8 +18,6 @@ test-support = [] [dependencies] bedrock = { workspace = true, features = ["schemars"] } anthropic = { workspace = true, features = ["schemars"] } -aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } -aws-config = { workspace = true, features = ["behavior-version-latest"]} aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] } anyhow.workspace = true base64.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 9242f80e6e16c..4d12597412db7 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,7 +3,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use strum::EnumIter; use ui::IconName; - +use crate::fake_provider::language_model_id; use crate::LanguageModelAvailability; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -12,6 +12,7 @@ pub enum CloudModel { Anthropic(anthropic::Model), OpenAi(open_ai::Model), Google(google_ai::Model), + Bedrock(bedrock::Model) } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)] @@ -32,6 +33,7 @@ impl CloudModel { Self::Anthropic(model) => model.id(), Self::OpenAi(model) => model.id(), Self::Google(model) => model.id(), + Self::Bedrock(model) => model.id(), } } @@ -40,6 +42,7 @@ impl CloudModel { Self::Anthropic(model) => model.display_name(), Self::OpenAi(model) => model.display_name(), Self::Google(model) => model.display_name(), + Self::Bedrock(model) => model.display_name(), } } @@ -55,6 +58,7 @@ impl CloudModel { Self::Anthropic(model) => model.max_token_count(), Self::OpenAi(model) => model.max_token_count(), Self::Google(model) => model.max_token_count(), + Self::Bedrock(model) => model.max_token_count(), } } @@ -91,6 +95,13 @@ impl CloudModel { LanguageModelAvailability::RequiresPlan(Plan::ZedPro) } }, + Self::Bedrock(model) => match model { + /* + TODO: Get guidance from the Zed team on what Pro means, since they're technically + all pay per use + */ + _ => LanguageModelAvailability::RequiresPlan(Plan::ZedPro) + }, } } } diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs index 7b5ac88dea1b7..6ee246e7871d5 100644 --- a/crates/language_model/src/model/mod.rs +++ b/crates/language_model/src/model/mod.rs @@ -4,3 +4,4 @@ pub use anthropic::Model as AnthropicModel; pub use cloud_model::*; pub use ollama::Model as OllamaModel; pub use open_ai::Model as OpenAiModel; +pub use bedrock::Model as BedrockModel; diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 00d948bd2d4a7..b836d70cd263e 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -14,6 +14,9 @@ path = "src/language_models.rs" [dependencies] anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true +aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } +aws-config = { workspace = true, features = ["behavior-version-latest"]} +bedrock.workspace = true client.workspace = true collections.workspace = true copilot = { workspace = true, features = ["schemars"] } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index b22819c0066a0..119764f1a1362 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -1,34 +1,37 @@ -use crate::{ - settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, -}; -use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; +use crate::AllLanguageModelSettings; +use bedrock::{BedrockError, BedrockEvent, bedrock_client, Model}; use anyhow::{anyhow, Context as _, Result}; -use aws_config as config; -use aws_config::meta::credentials::CredentialsProviderChain; -use aws_config::Region; -use aws_credential_types::Credentials; -use aws_sdk_bedrockruntime as bedrock_client; -use aws_sdk_bedrockruntime::Config; use collections::{BTreeMap, HashMap}; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _}; -use gpui::{AnyView, AppContext, AsyncAppContext, FontStyle, Model, ModelContext, Subscription, Task, TextStyle, View, WhiteSpace}; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle, + View, WhiteSpace, +}; +use http_client::HttpClient; +use language_model::{ + LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, +}; +use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; +use aws_config::Region; +use aws_credential_types::Credentials; use serde_json::Value; use strum::IntoEnumIterator; -use bedrock::{BedrockError, ContentDelta, Event, ResponseContent}; +use bedrock::bedrock_client::Config; use theme::ThemeSettings; use ui::{prelude::*, Icon, IconName, Tooltip}; use util::{maybe, ResultExt}; + const PROVIDER_ID : &str = "amazon-bedrock"; const PROVIDER_NAME : &str = "Amazon Bedrock"; @@ -211,8 +214,8 @@ impl BedrockLanguageModelProvider { struct BedrockModel { id: LanguageModelId, - model: bedrock::Model, - state: Model, + model: Model, + state: gpui::Model, request_limiter: RateLimiter, } @@ -221,8 +224,8 @@ impl BedrockModel { &self, request: bedrock::Request, cx: &AsyncAppContext - ) -> BoxFuture<'static, Result>>> { - + ) -> BoxFuture<'static, Result>>> { + todo!() } } @@ -287,72 +290,21 @@ impl LanguageModel for BedrockModel { } fn cache_configuration(&self) -> Option { - self.model - .cache_configuration() - .map(|config| LanguageModelCacheConfiguration { - max_cache_anchors: config.max_cache_anchors, - should_speculate: config.should_speculate, - min_total_token: config.min_total_token, - }) + unimplemented!() } } -fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> BoxFuture<'static, Result> { - cx.background_executor() - .spawn(async move { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - use crate::MessageContent; - - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - unimplemented!(); - } - MessageContent::ToolResult(tool_result) => { - unimplemented!(); - } - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| tokens + tokens_from_images) - }) - .boxed() +// TODO: just call the ConverseOutput.usage() method: +// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output +pub fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> BoxFuture<'static, Result> { + todo!() } pub fn map_to_language_model_completion_events( - events: Pin>>>, + events: Pin>>>, ) -> impl Stream> { struct State { - events: Pin>>> + events: Pin>>> } futures::stream::unfold( @@ -363,42 +315,11 @@ pub fn map_to_language_model_completion_events( while let Some(event) = state.events.next().await { match event { Ok(event) => match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - return Some(( - Some(Ok(LanguageModelCompletionEvent::Text(text))), - state, - )); - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - return Some(( - Some(Ok(LanguageModelCompletionEvent::Text(text))), - state, - )); - } - _ => {} - }, - Event::MessageDelta { delta, .. } => { - if let Some(stop_reason) = delta.stop_reason.as_deref() { - let stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - _ => StopReason::EndTurn, - }; - - return Some(( - Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))), - state, - )); - } - } - _ => {} + BedrockEvent::ContentBlockDelta(_) => {} + BedrockEvent::ContentBlockStart(_) => {} + BedrockEvent::MessageStart(_) => {} + BedrockEvent::MessageStop(_) => {} + BedrockEvent::Metadata(_) => {} }, Err(err) => { return Some((Some(Err(anyhow!(err))), state)); @@ -441,13 +362,6 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, - cache_configuration: model.cache_configuration.as_ref().map(|config| { - bedrock::BedrockModelCacheConfiguration { - max_cache_anchors: config.max_cache_anchors, - should_speculate: config.should_speculate, - min_total_token: config.min_total_token, - } - }), max_output_tokens: model.max_output_tokens, default_temperature: model.default_temperature, }, @@ -488,7 +402,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { impl LanguageModelProviderState for BedrockLanguageModelProvider { type ObservableEntity = State; - fn observable_entity(&self) -> Option> { + fn observable_entity(&self) -> Option> { Some(self.state.clone()) } } @@ -497,7 +411,7 @@ struct ConfigurationView { access_key_id_editor: View, secret_access_key_editor: View, region_editor: View, - state: Model, + state: gpui::Model, load_credentials_task: Option>, } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index f54e8c8d19b40..82ec642c3c268 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -40,12 +40,13 @@ use std::{ }; use strum::IntoEnumIterator; use thiserror::Error; +use tiktoken_rs::model; use ui::{prelude::*, TintColor}; use crate::provider::anthropic::map_to_language_model_completion_events; use crate::AllLanguageModelSettings; - use super::anthropic::count_anthropic_tokens; +use super::bedrock::get_bedrock_tokens; pub const PROVIDER_NAME: &str = "Zed"; @@ -72,6 +73,7 @@ pub enum AvailableProvider { Anthropic, OpenAi, Google, + Bedrock, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -287,6 +289,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider { models.insert(model.id().to_string(), CloudModel::Google(model)); } } + for model in bedrock::Model::iter() { + if !matches!(model, bedrock::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Bedrock(model)); + } + } } else { models.insert( anthropic::Model::Claude3_5Sonnet.id().to_string(), @@ -336,6 +343,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider { display_name: model.display_name.clone(), max_tokens: model.max_tokens, }), + AvailableProvider::Bedrock => CloudModel::Bedrock(bedrock::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + default_temperature: model.default_temperature, + }), }; models.insert(model.id().to_string(), model.clone()); } @@ -568,6 +582,7 @@ impl LanguageModel for CloudLanguageModel { }) } CloudModel::OpenAi(_) | CloudModel::Google(_) => None, + CloudModel::Bedrock(_) => unimplemented!() } } @@ -597,6 +612,7 @@ impl LanguageModel for CloudLanguageModel { } .boxed() } + CloudModel::Bedrock(_) => get_bedrock_tokens(request, cx), } } @@ -689,6 +705,9 @@ impl LanguageModel for CloudLanguageModel { } .boxed() } + CloudModel::Bedrock(_) => { + todo!() + } } } @@ -790,6 +809,7 @@ impl LanguageModel for CloudLanguageModel { CloudModel::Google(_) => { future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed() } + CloudModel::Bedrock(_) => { todo!() } } } } From 05d7e2af7e6769abb4d38a03aea85b1be2a97cde Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Mon, 25 Nov 2024 16:43:45 -0500 Subject: [PATCH 09/13] Removed bedrock from the cloudmodel --- crates/bedrock/src/bedrock.rs | 46 ++++------- crates/collab/src/llm.rs | 3 +- .../language_model/src/model/cloud_model.rs | 13 +--- .../language_models/src/provider/bedrock.rs | 78 +++++++++++++++---- crates/language_models/src/provider/cloud.rs | 23 +----- 5 files changed, 81 insertions(+), 82 deletions(-) diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index a13a23152e64c..925e60cd931e7 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -2,14 +2,17 @@ mod models; use std::{str::FromStr}; use std::any::Any; -use anyhow::{Context, Result}; -use aws_sdk_bedrockruntime::types::{ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ConverseStreamOutput, Message, MessageStartEvent, MessageStopEvent}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; +use std::io::BufReader; +use anyhow::{anyhow, Context, Result}; +use aws_sdk_bedrockruntime::types::{ContentBlockDeltaEvent, ContentBlockStart, ContentBlockStartEvent, ConverseStreamMetadataEvent, ConverseStreamOutput, Message, MessageStartEvent, MessageStopEvent}; +use futures::{stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; use thiserror::Error; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; +use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError; +pub use bedrock::types::ConverseStreamOutput as BedrockStreamingResponse; pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; pub use bedrock::types::ContentBlock as BedrockRequestContent; use bedrock::types::ConverseOutput as Response; @@ -44,7 +47,7 @@ pub async fn complete( pub async fn stream_completion( client: &bedrock::Client, request: Request, -) -> Result>, BedrockError> { // There is no generic Bedrock event Type? +) -> Result>, BedrockError> { // There is no generic Bedrock event Type? let response = bedrock::Client::converse_stream(client) .model_id(request.model) @@ -52,29 +55,15 @@ pub async fn stream_completion( let mut stream = match response { - Ok(output) => Ok(output.stream), + Ok(output) => Ok(output.stream()), Err(e) => Err( - BedrockError::SdkError(e.as_service_error().unwrap()) + BedrockError::ClientError(anyhow!(e)) ), }?; - loop { - let token = stream.recv().await; - match token { - Ok(Some(text)) => { - let next = get_converse_output_text(text)?; - print!("{}", next); - Ok(()) - } - Ok(None) => break, - Err(e) => Err(e - .as_service_error() - .map(BedrockConverseStreamError::from) - .unwrap_or(BedrockConverseStreamError( - "Unknown error receiving stream".into(), - ))), - }? - } + + + Ok() } fn get_converse_output_text( @@ -120,14 +109,7 @@ pub struct Metadata { #[derive(Error, Debug)] pub enum BedrockError { - SdkError(bedrock::Error), + ClientError(anyhow::Error), + ExtensionError(anyhow::Error), Other(anyhow::Error) } - -pub enum BedrockEvent { - ContentBlockDelta(ContentBlockDeltaEvent), - ContentBlockStart(ContentBlockStartEvent), - MessageStart(MessageStartEvent), - MessageStop(MessageStopEvent), - Metadata(ConverseStreamMetadataEvent), -} diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index dfd658b3179e4..114fd19bc77f6 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -407,8 +407,7 @@ async fn perform_completion( .boxed() } LanguageModelProvider::Bedrock => { - // TODO: implement this - unimplemented!() + Err("Unimplemented").boxed() } }; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 4d12597412db7..076943d7b6f17 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -12,7 +12,6 @@ pub enum CloudModel { Anthropic(anthropic::Model), OpenAi(open_ai::Model), Google(google_ai::Model), - Bedrock(bedrock::Model) } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)] @@ -33,7 +32,6 @@ impl CloudModel { Self::Anthropic(model) => model.id(), Self::OpenAi(model) => model.id(), Self::Google(model) => model.id(), - Self::Bedrock(model) => model.id(), } } @@ -42,7 +40,6 @@ impl CloudModel { Self::Anthropic(model) => model.display_name(), Self::OpenAi(model) => model.display_name(), Self::Google(model) => model.display_name(), - Self::Bedrock(model) => model.display_name(), } } @@ -58,7 +55,6 @@ impl CloudModel { Self::Anthropic(model) => model.max_token_count(), Self::OpenAi(model) => model.max_token_count(), Self::Google(model) => model.max_token_count(), - Self::Bedrock(model) => model.max_token_count(), } } @@ -94,14 +90,7 @@ impl CloudModel { | google_ai::Model::Custom { .. } => { LanguageModelAvailability::RequiresPlan(Plan::ZedPro) } - }, - Self::Bedrock(model) => match model { - /* - TODO: Get guidance from the Zed team on what Pro means, since they're technically - all pay per use - */ - _ => LanguageModelAvailability::RequiresPlan(Plan::ZedPro) - }, + } } } } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 119764f1a1362..82fd3e2eeace1 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -1,5 +1,5 @@ use crate::AllLanguageModelSettings; -use bedrock::{BedrockError, BedrockEvent, bedrock_client, Model}; +use bedrock::{BedrockError, BedrockEvent, bedrock_client, Model, BedrockStreamingResponse}; use anyhow::{anyhow, Context as _, Result}; use collections::{BTreeMap, HashMap}; use editor::{Editor, EditorElement, EditorStyle}; @@ -27,6 +27,7 @@ use aws_credential_types::Credentials; use serde_json::Value; use strum::IntoEnumIterator; use bedrock::bedrock_client::Config; +use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput}; use theme::ThemeSettings; use ui::{prelude::*, Icon, IconName, Tooltip}; use util::{maybe, ResultExt}; @@ -215,6 +216,7 @@ impl BedrockLanguageModelProvider { struct BedrockModel { id: LanguageModelId, model: Model, + runtime_client: bedrock_client::Client, state: gpui::Model, request_limiter: RateLimiter, } @@ -224,8 +226,14 @@ impl BedrockModel { &self, request: bedrock::Request, cx: &AsyncAppContext - ) -> BoxFuture<'static, Result>>> { - todo!() + ) -> BoxFuture<'static, Result>>> { + let bedrock_client = self.runtime_client.clone(); + + async move { + let request = bedrock::stream_completion(&bedrock_client, request); + + request.await.context("failed to perform stream completion") + }.boxed() } } @@ -278,11 +286,8 @@ impl LanguageModel for BedrockModel { ); let request = self.stream_completion(request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await.map_err(|err| anyhow!(err))?; - Ok(map_to_language_model_completion_events(response)) - }); - async move { Ok(future.await?.boxed()) }.boxed() + let future + async move { Ok(request.await?.boxed()) }.boxed() } fn use_any_tool(&self, request: LanguageModelRequest, name: String, description: String, schema: Value, cx: &AsyncAppContext) -> BoxFuture<'static, Result>>> { @@ -301,10 +306,10 @@ pub fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> Box } pub fn map_to_language_model_completion_events( - events: Pin>>>, + events: Pin>>>, ) -> impl Stream> { struct State { - events: Pin>>> + events: Pin>>> } futures::stream::unfold( @@ -315,11 +320,54 @@ pub fn map_to_language_model_completion_events( while let Some(event) = state.events.next().await { match event { Ok(event) => match event { - BedrockEvent::ContentBlockDelta(_) => {} - BedrockEvent::ContentBlockStart(_) => {} - BedrockEvent::MessageStart(_) => {} - BedrockEvent::MessageStop(_) => {} - BedrockEvent::Metadata(_) => {} + ConverseStreamOutput::ContentBlockDelta(cb_delta) => { + match cb_delta.delta { + Some(delta) => { + match delta { + ContentBlockDelta::Text(text_out) => { + return Some(( + Some(Ok(LanguageModelCompletionEvent::Text(text_out))), + state + )); + } + ContentBlockDelta::ToolUse(_) => unimplemented!("The Bedrock provider has not implemented tool use yet"), + ContentBlockDelta::Unknown => { + return Some((Some(Err(anyhow!("Unknown Delta type, Update the Bedrock SDK"))), state)) + } + _ => {} + } + } + None => return Some((None, state)) + } + } + ConverseStreamOutput::ContentBlockStart(cb_start) => { + match cb_start.start { + Some(start) => { + match start { + ContentBlockStart::ToolUse(_) => unimplemented!("The Bedrock provider has not implemented tool use yet"), + ContentBlockStart::Unknown => { + return Some((Some(Err(anyhow!("Unknown Delta type, Update the Bedrock SDK"))), state)) + } + _ => {} + } + } + None => {} + } + } + ConverseStreamOutput::ContentBlockStop(cb_stop) => { + unimplemented!("The Bedrock provider has not implemented tool use yet, this event will only be received on tool use") + } + ConverseStreamOutput::MessageStart(message) => {} + ConverseStreamOutput::MessageStop(_) => { + // This contains information if response generation has stopped for any reason + } + ConverseStreamOutput::Metadata(metadata) => { + // This contains stream metadata including token usage, metrics and traces for guardrail behaviour + } + ConverseStreamOutput::Unknown => { + return Some((Some(Err(anyhow!("Unknown event type, Update the Bedrock SDK"))), state)) + } + _ => {} }, Err(err) => { return Some((Some(Err(anyhow!(err))), state)); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 82ec642c3c268..e354cf67a3636 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -73,7 +73,6 @@ pub enum AvailableProvider { Anthropic, OpenAi, Google, - Bedrock, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -289,11 +288,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider { models.insert(model.id().to_string(), CloudModel::Google(model)); } } - for model in bedrock::Model::iter() { - if !matches!(model, bedrock::Model::Custom { .. }) { - models.insert(model.id().to_string(), CloudModel::Bedrock(model)); - } - } } else { models.insert( anthropic::Model::Claude3_5Sonnet.id().to_string(), @@ -342,14 +336,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, - }), - AvailableProvider::Bedrock => CloudModel::Bedrock(bedrock::Model::Custom { - name: model.name.clone(), - display_name: model.display_name.clone(), - max_tokens: model.max_tokens, - max_output_tokens: model.max_output_tokens, - default_temperature: model.default_temperature, - }), + }) }; models.insert(model.id().to_string(), model.clone()); } @@ -581,8 +568,7 @@ impl LanguageModel for CloudLanguageModel { min_total_token: cache.min_total_token, }) } - CloudModel::OpenAi(_) | CloudModel::Google(_) => None, - CloudModel::Bedrock(_) => unimplemented!() + CloudModel::OpenAi(_) | CloudModel::Google(_) => None } } @@ -612,7 +598,6 @@ impl LanguageModel for CloudLanguageModel { } .boxed() } - CloudModel::Bedrock(_) => get_bedrock_tokens(request, cx), } } @@ -705,9 +690,6 @@ impl LanguageModel for CloudLanguageModel { } .boxed() } - CloudModel::Bedrock(_) => { - todo!() - } } } @@ -809,7 +791,6 @@ impl LanguageModel for CloudLanguageModel { CloudModel::Google(_) => { future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed() } - CloudModel::Bedrock(_) => { todo!() } } } } From 5c4a53310d4a42b377a37f8a6cafdb3b932ce93b Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Tue, 26 Nov 2024 18:06:30 -0500 Subject: [PATCH 10/13] Fixed a whole host of mismatched types, Arc will be the death of me --- crates/bedrock/src/bedrock.rs | 77 +++-- .../language_model/src/model/cloud_model.rs | 1 - crates/language_model/src/request.rs | 30 +- .../language_models/src/provider/bedrock.rs | 276 ++++++++++-------- 4 files changed, 201 insertions(+), 183 deletions(-) diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 925e60cd931e7..58601340bde14 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,25 +1,25 @@ mod models; -use std::{str::FromStr}; -use std::any::Any; -use std::io::BufReader; use anyhow::{anyhow, Context, Result}; -use aws_sdk_bedrockruntime::types::{ContentBlockDeltaEvent, ContentBlockStart, ContentBlockStartEvent, ConverseStreamMetadataEvent, ConverseStreamOutput, Message, MessageStartEvent, MessageStopEvent}; -use futures::{stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; +use std::any::Any; +use std::str::FromStr; use thiserror::Error; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; +use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, Message}; use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError; -pub use bedrock::types::ConverseStreamOutput as BedrockStreamingResponse; pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; pub use bedrock::types::ContentBlock as BedrockRequestContent; +pub use bedrock::types::ConversationRole as BedrockRole; use bedrock::types::ConverseOutput as Response; +pub use bedrock::types::ConverseStreamOutput as BedrockStreamingResponse; pub use bedrock::types::Message as BedrockMessage; -pub use bedrock::types::ConversationRole as BedrockRole; pub use bedrock::types::ResponseStream as BedrockResponseStream; - +use futures::stream::BoxStream; +use futures::StreamExt; +use strum::Display; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html @@ -47,58 +47,55 @@ pub async fn complete( pub async fn stream_completion( client: &bedrock::Client, request: Request, -) -> Result>, BedrockError> { // There is no generic Bedrock event Type? +) -> Result, BedrockError> { let response = bedrock::Client::converse_stream(client) .model_id(request.model) .set_messages(request.messages.into()).send().await; - let mut stream = match response { - Ok(output) => Ok(output.stream()), + match response { + Ok(mut output) => { + match output.stream.recv().await { + Ok(resp) => { + match resp { + None => { + Ok(None) + } + Some(output) => { + Ok(Some(output)) + } + } + } + Err(e) => { + Err(BedrockError::ClientError(anyhow!("Failed to receive response from Bedrock"))) + } + } + }, Err(e) => Err( BedrockError::ClientError(anyhow!(e)) ), - }?; - - - - Ok() + } } -fn get_converse_output_text( - output: ConverseStreamOutput, -) -> Result { - Ok(match output { - ConverseStreamOutput::ContentBlockDelta(c) => { - match c.delta() { - Some(delta) => delta.as_text().cloned().unwrap_or_else(|_| "".into()), - None => "".into(), - } - } - _ => { - String::from("") - } - }) -} //TODO: A LOT of these types need to re-export the Bedrock types instead of making custom ones -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub struct Request { pub model: String, pub max_tokens: u32, - pub messages: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] + pub messages: Vec, + // #[serde(default, skip_serializing_if = "Option::is_none")] pub system: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] + // #[serde(default, skip_serializing_if = "Option::is_none")] pub metadata: Option, - #[serde(default, skip_serializing_if = "Vec::is_empty")] + // #[serde(default, skip_serializing_if = "Vec::is_empty")] pub stop_sequences: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] + // #[serde(default, skip_serializing_if = "Option::is_none")] pub temperature: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] + // #[serde(default, skip_serializing_if = "Option::is_none")] pub top_k: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] + // #[serde(default, skip_serializing_if = "Option::is_none")] pub top_p: Option, } @@ -107,7 +104,7 @@ pub struct Metadata { pub user_id: Option, } -#[derive(Error, Debug)] +#[derive(Error, Debug, Display)] pub enum BedrockError { ClientError(anyhow::Error), ExtensionError(anyhow::Error), diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 076943d7b6f17..5ac26d0c7b8a5 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,7 +3,6 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use strum::EnumIter; use ui::IconName; -use crate::fake_provider::language_model_id; use crate::LanguageModelAvailability; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 1cadbebc3327f..4f5bf1956fb46 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,11 +1,12 @@ use std::io::{Cursor, Write}; -use aws_sdk_bedrockruntime::types::ImageBlock; +use aws_sdk_bedrockruntime::types::{ContentBlock, ImageBlock}; use crate::role::Role; use crate::LanguageModelToolUse; use base64::write::EncoderWriter; use gpui::{point, size, AppContext, DevicePixels, Image, ImageFormat, ObjectFit, RenderImage, Size, Task}; use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder}; use serde::{Deserialize, Serialize}; +use bedrock::BedrockMessage; use ui::{px, SharedString}; use util::ResultExt; @@ -417,7 +418,7 @@ impl LanguageModelRequest { default_temperature: f32, max_output_tokens: u32 ) -> bedrock::Request { - let mut new_messages: Vec = Vec::new(); + let mut new_messages: Vec = Vec::new(); let mut system_message = String::new(); for message in self.messages { @@ -427,20 +428,13 @@ impl LanguageModelRequest { match message.role { Role::User | Role::Assistant => { - let cache_control = if message.cache { - Some(bedrock::CacheControl { - cache_type: bedrock::CacheControlType::Ephemeral, - }) - } else { - None - }; - let bedrock_message_content: Vec = message + let bedrock_message_content: Vec = message .content .into_iter() .filter_map(|content| match content { MessageContent::Text(text) => { if !text.is_empty() { - Some(bedrock::RequestContent::Text(text)) + Some(ContentBlock::Text(text)) } else { None } @@ -458,8 +452,8 @@ impl LanguageModelRequest { }) .collect(); let bedrock_role = match message.role { - Role::User => bedrock::ConversationRole::User, - Role::Assistant => bedrock::ConversationRole::Assistant, + Role::User => bedrock::BedrockRole::User, + Role::Assistant => bedrock::BedrockRole::Assistant, Role::System => unreachable!("System role should never occur here") }; if let Some(last_message) = new_messages.last_mut() { @@ -468,10 +462,12 @@ impl LanguageModelRequest { continue; } } - new_messages.push(bedrock::Message { - role: bedrock_role, - content: bedrock_message_content - }); + new_messages.push( + BedrockMessage::builder() + .role(bedrock_role) + .set_content(Some(bedrock_message_content)) + .build().unwrap() // unsafe unwrap, but it should be fine + ); } Role::System => { if !system_message.is_empty() { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 82fd3e2eeace1..7a633d72e57f4 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -1,6 +1,12 @@ use crate::AllLanguageModelSettings; -use bedrock::{BedrockError, BedrockEvent, bedrock_client, Model, BedrockStreamingResponse}; use anyhow::{anyhow, Context as _, Result}; +use aws_config::Region; +use aws_credential_types::Credentials; +use bedrock::bedrock_client::types::{ + ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput, +}; +use bedrock::bedrock_client::Config; +use bedrock::{bedrock_client, BedrockError, BedrockStreamingResponse, Model}; use collections::{BTreeMap, HashMap}; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; @@ -18,29 +24,25 @@ use language_model::{ use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use serde_json::Value; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; -use aws_config::Region; -use aws_credential_types::Credentials; -use serde_json::Value; use strum::IntoEnumIterator; -use bedrock::bedrock_client::Config; -use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput}; +use bedrock::bedrock_client::config::{IntoShared, SharedHttpClient}; use theme::ThemeSettings; use ui::{prelude::*, Icon, IconName, Tooltip}; use util::{maybe, ResultExt}; - -const PROVIDER_ID : &str = "amazon-bedrock"; -const PROVIDER_NAME : &str = "Amazon Bedrock"; +const PROVIDER_ID: &str = "amazon-bedrock"; +const PROVIDER_NAME: &str = "Amazon Bedrock"; #[derive(Default, Clone, Debug, PartialEq)] pub struct AmazonBedrockSettings { pub region: Option, pub credentials: Option, - pub available_models: Vec + pub available_models: Vec, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -66,24 +68,35 @@ const ZED_BEDROCK_SK: &str = "ZED_SECRET_ACCESS_KEY"; const ZED_BEDROCK_REGION: &str = "ZED_AWS_REGION"; pub struct State { - credentials: Option, + credentials: Option, credentials_from_env: bool, region: Option, - _subscription: Subscription + _subscription: Subscription, } - pub struct BedrockLanguageModelProvider { runtime_client: bedrock_client::Client, - state: gpui::Model + state: gpui::Model, } impl State { fn reset_credentials(&self, cx: &mut ModelContext) -> Task> { - let delete_aa_id= - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).bedrock.credentials.clone().unwrap().access_key_id); - let delete_sk: Task> = - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).bedrock.credentials.clone().unwrap().secret_access_key); + let delete_aa_id = cx.delete_credentials( + &AllLanguageModelSettings::get_global(cx) + .bedrock + .credentials + .clone() + .unwrap() + .access_key_id, + ); + let delete_sk: Task> = cx.delete_credentials( + &AllLanguageModelSettings::get_global(cx) + .bedrock + .credentials + .clone() + .unwrap() + .secret_access_key, + ); cx.spawn(|this, mut cx| async move { delete_aa_id.await.ok(); delete_sk.await.ok(); @@ -95,32 +108,34 @@ impl State { }) } - fn set_credentials(&mut self, access_key_id: String, secret_key: String, region: String, cx: &mut ModelContext) -> Task> { + fn set_credentials( + &mut self, + access_key_id: String, + secret_key: String, + region: String, + cx: &mut ModelContext, + ) -> Task> { let write_aa_id = cx.write_credentials( ZED_BEDROCK_AAID, // TODO: GET THIS REVIEWED, MAKE SURE IT DOESN'T BREAK STUFF LONG TERM "Bearer", - access_key_id.as_bytes() + access_key_id.as_bytes(), ); let write_sk = cx.write_credentials( - ZED_BEDROCK_SK, // TODO: GET THIS REVIEWED, MAKE SURE IT DOESN'T BREAK STUFF LONG TERM - "Bearer", - secret_key.as_bytes() - ); - let write_region = cx.write_credentials( - ZED_BEDROCK_REGION, + ZED_BEDROCK_SK, // TODO: GET THIS REVIEWED, MAKE SURE IT DOESN'T BREAK STUFF LONG TERM "Bearer", - region.as_bytes() + secret_key.as_bytes(), ); + let write_region = cx.write_credentials(ZED_BEDROCK_REGION, "Bearer", region.as_bytes()); cx.spawn(|this, mut cx| async move { write_aa_id.await?; write_sk.await?; write_region.await?; this.update(&mut cx, |this, cx| { - this.credentials = Some(AmazonBedrockCredentials{ + this.credentials = Some(AmazonBedrockCredentials { access_key_id, secret_access_key: secret_key, - session_token: None + session_token: None, }); this.region = Some(region); cx.notify(); @@ -138,64 +153,75 @@ impl State { Task::ready(Ok(())) } else { cx.spawn(|this, mut cx| async move { - let (aa_id, sk, region, from_env) = if let (Ok(aa_id), Ok(sk), Ok(region)) - = (std::env::var(ZED_BEDROCK_AAID), std::env::var(ZED_BEDROCK_SK), std::env::var(ZED_BEDROCK_REGION)) - { + let (aa_id, sk, region, from_env) = if let (Ok(aa_id), Ok(sk), Ok(region)) = ( + std::env::var(ZED_BEDROCK_AAID), + std::env::var(ZED_BEDROCK_SK), + std::env::var(ZED_BEDROCK_REGION), + ) { (aa_id, sk, region, true) } else { let (_, aa_id) = cx - .update(| cx | cx.read_credentials(ZED_BEDROCK_AAID))? + .update(|cx| cx.read_credentials(ZED_BEDROCK_AAID))? .await? .ok_or_else(|| anyhow!("Access key ID not found"))?; let (_, sk) = cx - .update(| cx | cx.read_credentials(ZED_BEDROCK_SK))? + .update(|cx| cx.read_credentials(ZED_BEDROCK_SK))? .await? .ok_or_else(|| anyhow!("Secret access key not found"))?; let (_, region) = cx - .update(| cx | cx.read_credentials(ZED_BEDROCK_REGION))? + .update(|cx| cx.read_credentials(ZED_BEDROCK_REGION))? .await? .ok_or_else(|| anyhow!("Region not found"))?; - - (String::from_utf8(aa_id)?, String::from_utf8(sk)?, String::from_utf8(region)?, false) + ( + String::from_utf8(aa_id)?, + String::from_utf8(sk)?, + String::from_utf8(region)?, + false, + ) }; this.update(&mut cx, |this, cx| { this.credentials_from_env = from_env; - this.credentials = Some(AmazonBedrockCredentials{ + this.credentials = Some(AmazonBedrockCredentials { access_key_id: aa_id, secret_access_key: sk, - session_token: None + session_token: None, }); this.region = Some(region); cx.notify(); }) }) } - } } impl BedrockLanguageModelProvider { - pub fn new(cx: &mut AppContext) -> Self { - + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { credentials: None, region: Some(String::from("us-east-1")), credentials_from_env: false, _subscription: cx.observe_global::(|_, cx| { cx.notify(); - }) + }), }); - let region_def: String = state.read(cx).region.clone() - .or_else(|| {Some(String::from("us-east-1"))}) + let region_def: String = state + .read(cx) + .region + .clone() + .or_else(|| Some(String::from("us-east-1"))) .unwrap(); - let creds_clone = &state.read(cx).credentials.clone() - .or_else(|| { Some(AmazonBedrockCredentials::default()) }) + let creds_clone = &state + .read(cx) + .credentials + .clone() + .or_else(|| Some(AmazonBedrockCredentials::default())) .unwrap(); let client_config = Config::builder() + .http_client(SharedHttpClient::new(http_client.as_ref())) .region(Region::new(region_def)) .credentials_provider(Credentials::from_keys( &creds_clone.clone().access_key_id, @@ -208,7 +234,7 @@ impl BedrockLanguageModelProvider { Self { runtime_client, - state + state, } } } @@ -225,15 +251,16 @@ impl BedrockModel { fn stream_completion( &self, request: bedrock::Request, - cx: &AsyncAppContext - ) -> BoxFuture<'static, Result>>> { - let bedrock_client = self.runtime_client.clone(); - + _: &AsyncAppContext + ) -> BoxFuture< + 'static, + Result>>, + > { async move { - let request = bedrock::stream_completion(&bedrock_client, request); - + let request = bedrock::stream_completion(&self.runtime_client, request); request.await.context("failed to perform stream completion") - }.boxed() + } + .boxed() } } @@ -286,11 +313,21 @@ impl LanguageModel for BedrockModel { ); let request = self.stream_completion(request, cx); - let future - async move { Ok(request.await?.boxed()) }.boxed() + let future = self.request_limiter.stream(async move { + let response = request.await.map_err(|err| anyhow!(err))?; + Ok(map_to_language_model_completion_events(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() } - fn use_any_tool(&self, request: LanguageModelRequest, name: String, description: String, schema: Value, cx: &AsyncAppContext) -> BoxFuture<'static, Result>>> { + fn use_any_tool( + &self, + request: LanguageModelRequest, + name: String, + description: String, + schema: Value, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { unimplemented!(); } @@ -301,7 +338,10 @@ impl LanguageModel for BedrockModel { // TODO: just call the ConverseOutput.usage() method: // https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output -pub fn get_bedrock_tokens(request: LanguageModelRequest, cx: &AppContext) -> BoxFuture<'static, Result> { +pub fn get_bedrock_tokens( + request: LanguageModelRequest, + cx: &AppContext, +) -> BoxFuture<'static, Result> { todo!() } @@ -309,7 +349,7 @@ pub fn map_to_language_model_completion_events( events: Pin>>>, ) -> impl Stream> { struct State { - events: Pin>>> + events: Pin>>>, } futures::stream::unfold( @@ -321,58 +361,48 @@ pub fn map_to_language_model_completion_events( match event { Ok(event) => match event { ConverseStreamOutput::ContentBlockDelta(cb_delta) => { - match cb_delta.delta { - Some(delta) => { - match delta { - ContentBlockDelta::Text(text_out) => { - return Some(( - Some(Ok(LanguageModelCompletionEvent::Text(text_out))), - state - )); - } - ContentBlockDelta::ToolUse(_) => unimplemented!("The Bedrock provider has not implemented tool use yet"), - ContentBlockDelta::Unknown => { - return Some((Some(Err(anyhow!("Unknown Delta type, Update the Bedrock SDK"))), state)) - } - _ => {} - } - } - None => return Some((None, state)) + if let Some(ContentBlockDelta::Text(text_out)) = cb_delta.delta { + return Some(( + Some(Ok(LanguageModelCompletionEvent::Text(text_out))), + state, + )); + } else if let Some(ContentBlockDelta::ToolUse(_)) = cb_delta.delta { + return Some(( + Some(Err(anyhow!("The Bedrock provider has not implemented tool use yet"))), + state, + )); + } else if cb_delta.delta.is_none() { + return Some((None, state)); } } ConverseStreamOutput::ContentBlockStart(cb_start) => { - match cb_start.start { - Some(start) => { - match start { - ContentBlockStart::ToolUse(_) => unimplemented!("The Bedrock provider has not implemented tool use yet"), - ContentBlockStart::Unknown => { - return Some((Some(Err(anyhow!("Unknown Delta type, Update the Bedrock SDK"))), state)) - } - _ => {} + if let Some(start) = cb_start.start { + match start { + ContentBlockStart::ToolUse(_) => { + return Some(( + Some(Err(anyhow!("The Bedrock provider has not implemented tool use yet"))), + state, + )) } + _ => {} } - None => {} } } - ConverseStreamOutput::ContentBlockStop(cb_stop) => { - unimplemented!("The Bedrock provider has not implemented tool use yet, this event will only be received on tool use") - } - ConverseStreamOutput::MessageStart(message) => {} - ConverseStreamOutput::MessageStop(_) => { - // This contains information if response generation has stopped for any reason - } - ConverseStreamOutput::Metadata(metadata) => { - // This contains stream metadata including token usage, metrics and traces for guardrail behaviour - } - ConverseStreamOutput::Unknown => { - return Some((Some(Err(anyhow!("Unknown event type, Update the Bedrock SDK"))), state)) + ConverseStreamOutput::ContentBlockStop(_) => { + return Some(( + Some(Err(anyhow!("The Bedrock provider has not implemented tool use yet, this event will only be received on tool use"))), + state, + )) } + ConverseStreamOutput::MessageStart(_) | + ConverseStreamOutput::MessageStop(_) | + ConverseStreamOutput::Metadata(_) => {} _ => {} }, - Err(err) => { - return Some((Some(Err(anyhow!(err))), state)); - } + Err(err) => return Some((Some(Err(anyhow!(err))), state)), } + + } None @@ -422,8 +452,9 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { Arc::new(BedrockModel { id: LanguageModelId::from(model.id().to_string()), model, + runtime_client: self.runtime_client.clone(), // too many copies of the bedrock client created here, figure out how to safely share it state: self.state.clone(), - request_limiter: RateLimiter::new(4) + request_limiter: RateLimiter::new(4), }) as Arc }) .collect() @@ -443,7 +474,8 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { } fn reset_credentials(&self, cx: &mut AppContext) -> Task> { - self.state.update(cx, |state, cx| state.reset_credentials(cx)) + self.state + .update(cx, |state, cx| state.reset_credentials(cx)) } } @@ -471,7 +503,7 @@ impl ConfigurationView { cx.observe(&state, |_, _, cx| { cx.notify(); }) - .detach(); + .detach(); let load_credentials_task = Some(cx.spawn({ let state = state.clone(); @@ -487,7 +519,7 @@ impl ConfigurationView { this.load_credentials_task = None; cx.notify(); }) - .log_err(); + .log_err(); } })); @@ -513,32 +545,28 @@ impl ConfigurationView { } fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let access_key_id = self.access_key_id_editor - .read(cx) - .text(cx) - .to_string(); - let secret_access_key = self.secret_access_key_editor - .read(cx) - .text(cx) - .to_string(); - let region = self.region_editor - .read(cx) - .text(cx) - .to_string(); + let access_key_id = self.access_key_id_editor.read(cx).text(cx).to_string(); + let secret_access_key = self.secret_access_key_editor.read(cx).text(cx).to_string(); + let region = self.region_editor.read(cx).text(cx).to_string(); let state = self.state.clone(); cx.spawn(|_, mut cx| async move { state - .update(&mut cx, |state, cx| state.set_credentials(access_key_id, secret_access_key, region, cx))? + .update(&mut cx, |state, cx| { + state.set_credentials(access_key_id, secret_access_key, region, cx) + })? .await }) - .detach_and_log_err(cx); + .detach_and_log_err(cx); } fn reset_credentials(&mut self, cx: &mut ViewContext) { - self.access_key_id_editor.update(cx, |editor, cx| editor.set_text("", cx)); - self.secret_access_key_editor.update(cx, |editor, cx| editor.set_text("", cx)); - self.region_editor.update(cx, |editor, cx| editor.set_text("", cx)); + self.access_key_id_editor + .update(cx, |editor, cx| editor.set_text("", cx)); + self.secret_access_key_editor + .update(cx, |editor, cx| editor.set_text("", cx)); + self.region_editor + .update(cx, |editor, cx| editor.set_text("", cx)); let state = self.state.clone(); cx.spawn(|_, mut cx| async move { @@ -546,7 +574,7 @@ impl ConfigurationView { .update(&mut cx, |state, cx| state.reset_credentials(cx))? .await }) - .detach_and_log_err(cx); + .detach_and_log_err(cx); } fn make_text_style(&self, cx: &ViewContext) -> TextStyle { @@ -685,5 +713,3 @@ impl Render for ConfigurationView { } } } - - From 97eb67707530853b3ddee622bfcd092c01d28673 Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Fri, 29 Nov 2024 14:35:04 -0500 Subject: [PATCH 11/13] Helper functions --- Cargo.lock | 2 + Cargo.toml | 1 + crates/bedrock/src/bedrock.rs | 57 +++++--------- crates/collab/src/llm.rs | 4 +- crates/collab/src/llm/authorization.rs | 1 + crates/http_client/Cargo.toml | 1 + crates/http_client/src/http_client.rs | 77 ++++++++++++++++++- crates/language_models/Cargo.toml | 1 + .../language_models/src/provider/bedrock.rs | 23 ++++-- crates/language_models/src/settings.rs | 6 +- 10 files changed, 125 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3bb78680956a6..7b5c9592c8712 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5656,6 +5656,7 @@ name = "http_client" version = "0.1.0" dependencies = [ "anyhow", + "aws-smithy-runtime-api", "bytes 1.8.0", "derive_more", "futures 0.3.31", @@ -6627,6 +6628,7 @@ dependencies = [ "anyhow", "aws-config", "aws-credential-types", + "aws-smithy-runtime-api", "bedrock", "client", "collections", diff --git a/Cargo.toml b/Cargo.toml index bdf5608badd5d..d5892bf3d12fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -336,6 +336,7 @@ async-trait = "0.1" async-tungstenite = "0.28" async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } +aws-smithy-runtime-api = { version = "1.7.3" } aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] } aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = { version = "1.57.0", features = ["behavior-version-latest"]} diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 58601340bde14..1317ff37c96bf 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,15 +1,14 @@ mod models; use anyhow::{anyhow, Context, Result}; +use aws_sdk_bedrockruntime::config::SharedHttpClient; +use http_client::HttpClient; use serde::{Deserialize, Serialize}; -use std::any::Any; -use std::str::FromStr; +use std::sync::Arc; use thiserror::Error; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; -use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, Message}; -use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError; pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; pub use bedrock::types::ContentBlock as BedrockRequestContent; pub use bedrock::types::ConversationRole as BedrockRole; @@ -17,8 +16,6 @@ use bedrock::types::ConverseOutput as Response; pub use bedrock::types::ConverseStreamOutput as BedrockStreamingResponse; pub use bedrock::types::Message as BedrockMessage; pub use bedrock::types::ResponseStream as BedrockResponseStream; -use futures::stream::BoxStream; -use futures::StreamExt; use strum::Display; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html @@ -32,15 +29,13 @@ pub async fn complete( let mut response = bedrock::Client::converse(client) .model_id(request.model.clone()) .set_messages(request.messages.into()) - .send().await.context("Failed to send request to Bedrock"); + .send() + .await + .context("Failed to send request to Bedrock"); match response { - Ok(output) => { - Ok(output.output.unwrap()) - } - Err(err) => { - Err(BedrockError::Other(err)) - } + Ok(output) => Ok(output.output.unwrap()), + Err(err) => Err(BedrockError::Other(err)), } } @@ -48,33 +43,23 @@ pub async fn stream_completion( client: &bedrock::Client, request: Request, ) -> Result, BedrockError> { - let response = bedrock::Client::converse_stream(client) .model_id(request.model) - .set_messages(request.messages.into()).send().await; - + .set_messages(request.messages.into()) + .send() + .await; match response { - Ok(mut output) => { - match output.stream.recv().await { - Ok(resp) => { - match resp { - None => { - Ok(None) - } - Some(output) => { - Ok(Some(output)) - } - } - } - Err(e) => { - Err(BedrockError::ClientError(anyhow!("Failed to receive response from Bedrock"))) - } - } + Ok(mut output) => match output.stream.recv().await { + Ok(resp) => match resp { + None => Ok(None), + Some(output) => Ok(Some(output)), + }, + Err(e) => Err(BedrockError::ClientError(anyhow!( + "Failed to receive response from Bedrock" + ))), }, - Err(e) => Err( - BedrockError::ClientError(anyhow!(e)) - ), + Err(e) => Err(BedrockError::ClientError(anyhow!(e))), } } @@ -108,5 +93,5 @@ pub struct Metadata { pub enum BedrockError { ClientError(anyhow::Error), ExtensionError(anyhow::Error), - Other(anyhow::Error) + Other(anyhow::Error), } diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 114fd19bc77f6..9a8b465b6305e 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -406,9 +406,7 @@ async fn perform_completion( }) .boxed() } - LanguageModelProvider::Bedrock => { - Err("Unimplemented").boxed() - } + LanguageModelProvider::Bedrock => Err(anyhow!("Unimplemented")).boxed(), }; Ok(Response::new(Body::wrap_stream(TokenCountingStream { diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 9f82af51c39b7..2cfb841c9546c 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -77,6 +77,7 @@ fn authorize_access_for_country( LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code), LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code), LanguageModelProvider::Google => google_ai::is_supported_country(country_code), + LanguageModelProvider::Bedrock => todo!(), }; if !is_country_supported_by_provider { Err(Error::http( diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index ac8e254b84f60..fd0d9333fe47f 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -17,6 +17,7 @@ doctest = true [dependencies] bytes.workspace = true +aws-smithy-runtime-api.workspace = true anyhow.workspace = true derive_more.workspace = true futures.workspace = true diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index ec3a4e03c40f1..28950a86fa62b 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -3,13 +3,20 @@ pub mod github; pub use anyhow::{anyhow, Result}; pub use async_body::{AsyncBody, Inner}; +use aws_smithy_runtime_api::http::StatusCode as AwsStatusCode; use derive_more::Deref; pub use http::{self, Method, Request, Response, StatusCode, Uri}; +use std::fmt; +use aws_smithy_runtime_api::client::http::{HttpConnector, HttpConnectorFuture}; +use aws_smithy_runtime_api::client::orchestrator::{ + HttpRequest as AwsRequest, HttpResponse as AwsResponse, +}; use futures::future::BoxFuture; use http::request::Builder; + #[cfg(feature = "test-support")] -use std::fmt; +use std::fmt::{Debug, Formatter}; use std::{ any::type_name, sync::{Arc, Mutex}, @@ -326,6 +333,74 @@ impl HttpClient for BlockedHttpClient { } } +impl Debug for HttpClientWithUrl { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + todo!() + } +} + +impl HttpConnector for HttpClientWithUrl { + fn call(&self, request: HttpRequest) -> HttpConnectorFuture { + let response = self.client.send(request); + } +} + +// Helper to convert AWS SDK request to your HttpClient request +fn convert_aws_request(aws_request: AwsRequest) -> Result, anyhow::Error> { + let owned_request = match aws_request.try_clone() { + Some(req) => req, + None => { + return Err(anyhow!( + "Failed to clone the AWS request, this is likely a bug in the SDK" + )) + } + }; + + // Convert the request details + let mut builder = http::Request::builder() + .method(owned_request.method()) + .uri(owned_request.uri()); + + // Add headers + for (name, value) in owned_request.headers() { + builder = builder.header(name, value); + } + + // Convert body + let body = AsyncBody::from(owned_request.body()); + + Ok(builder.body(body)?) +} + +// Helper to convert your response to AWS SDK response +fn convert_to_aws_response( + response: http::Response, +) -> Result { + let (parts, body) = response.into_parts(); + + let mut aws_response = + AwsResponse::new(AwsStatusCode::from(response.status()), response.body()); + + // Copy headers + for (name, value) in parts.headers { + let val = match value.to_str() { + Ok(val) => val, + Err(e) => { + return Err(anyhow!("Failed to convert header value to string: {}", e)); + } + }; + let header_name = match name { + None => { + return Err(anyhow!("Failed to convert header name to string")); + } + Some(header) => header.as_str(), + }; + aws_response.headers_mut().insert(header_name, val); + } + + Ok(aws_response) +} + #[cfg(feature = "test-support")] type FakeHttpHandler = Box< dyn Fn(Request) -> BoxFuture<'static, Result, anyhow::Error>> diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b836d70cd263e..fb14afc23a9d9 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -14,6 +14,7 @@ path = "src/language_models.rs" [dependencies] anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true +aws-smithy-runtime-api.workspace = true aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } aws-config = { workspace = true, features = ["behavior-version-latest"]} bedrock.workspace = true diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 7a633d72e57f4..0792ade2a17ee 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -2,6 +2,9 @@ use crate::AllLanguageModelSettings; use anyhow::{anyhow, Context as _, Result}; use aws_config::Region; use aws_credential_types::Credentials; +use aws_smithy_runtime_api::client::http::{http_client_fn, SharedHttpClient, SharedHttpConnector}; +use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse}; +use aws_smithy_runtime_api::http::StatusCode; use bedrock::bedrock_client::types::{ ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput, }; @@ -15,7 +18,7 @@ use gpui::{ AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle, View, WhiteSpace, }; -use http_client::HttpClient; +use http_client::{http, AsyncBody, HttpClient}; use language_model::{ LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, @@ -27,10 +30,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use settings::{Settings, SettingsStore}; use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; use strum::IntoEnumIterator; -use bedrock::bedrock_client::config::{IntoShared, SharedHttpClient}; use theme::ThemeSettings; use ui::{prelude::*, Icon, IconName, Tooltip}; use util::{maybe, ResultExt}; @@ -220,8 +221,20 @@ impl BedrockLanguageModelProvider { .or_else(|| Some(AmazonBedrockCredentials::default())) .unwrap(); + let aws_connector = http_client_fn(move |settings, components| { + let client = http_client.clone(); + SharedHttpConnector::new(move |request| { + let client = client.clone(); + Box::pin(async move { + let request = convert_aws_request(request)?; + let response = client.send(request).await?; + Ok(convert_to_aws_response(response)?) + }) + }) + }); + let client_config = Config::builder() - .http_client(SharedHttpClient::new(http_client.as_ref())) + .http_client(SharedHttpClient::new(aws_connector)) .region(Region::new(region_def)) .credentials_provider(Credentials::from_keys( &creds_clone.clone().access_key_id, @@ -251,7 +264,7 @@ impl BedrockModel { fn stream_completion( &self, request: bedrock::Request, - _: &AsyncAppContext + _: &AsyncAppContext, ) -> BoxFuture< 'static, Result>>, diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index 36c2ff5554759..17df8f3ad798c 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -8,6 +8,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsSources}; +use crate::provider::bedrock::{AmazonBedrockCredentials, AmazonBedrockSettings}; use crate::provider::{ self, anthropic::AnthropicSettings, @@ -17,7 +18,6 @@ use crate::provider::{ ollama::OllamaSettings, open_ai::OpenAiSettings, }; -use crate::provider::bedrock::{AmazonBedrockCredentials, AmazonBedrockSettings}; /// Initializes the language model settings. pub fn init(fs: Arc, cx: &mut AppContext) { @@ -54,8 +54,8 @@ pub fn init(fs: Arc, cx: &mut AppContext) { #[derive(Default)] pub struct AllLanguageModelSettings { - pub bedrock: AmazonBedrockSettings, pub anthropic: AnthropicSettings, + pub bedrock: AmazonBedrockSettings, pub ollama: OllamaSettings, pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, @@ -80,7 +80,7 @@ pub struct BedrockSettingsContent { pub region: Option, pub access_key_id: Option, pub secret_access_key: Option, - pub available_models: Option> + pub available_models: Option>, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] From 5db57d529a90b4f660a7378919c743b8aade9d8c Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Mon, 2 Dec 2024 15:57:24 -0500 Subject: [PATCH 12/13] Trying some additional things --- crates/bedrock/src/bedrock.rs | 31 +++++++++++++------ crates/http_client/src/http_client.rs | 29 ++++++++++------- .../language_models/src/provider/bedrock.rs | 5 +-- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 1317ff37c96bf..cb4bbc43fb722 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -2,6 +2,8 @@ mod models; use anyhow::{anyhow, Context, Result}; use aws_sdk_bedrockruntime::config::SharedHttpClient; +use futures::stream; +use futures::StreamExt; use http_client::HttpClient; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -16,6 +18,7 @@ use bedrock::types::ConverseOutput as Response; pub use bedrock::types::ConverseStreamOutput as BedrockStreamingResponse; pub use bedrock::types::Message as BedrockMessage; pub use bedrock::types::ResponseStream as BedrockResponseStream; +use futures::stream::BoxStream; use strum::Display; //TODO: Re-export the Bedrock stuff // https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html @@ -42,7 +45,7 @@ pub async fn complete( pub async fn stream_completion( client: &bedrock::Client, request: Request, -) -> Result, BedrockError> { +) -> Result, BedrockError>>, BedrockError> { let response = bedrock::Client::converse_stream(client) .model_id(request.model) .set_messages(request.messages.into()) @@ -50,15 +53,23 @@ pub async fn stream_completion( .await; match response { - Ok(mut output) => match output.stream.recv().await { - Ok(resp) => match resp { - None => Ok(None), - Some(output) => Ok(Some(output)), - }, - Err(e) => Err(BedrockError::ClientError(anyhow!( - "Failed to receive response from Bedrock" - ))), - }, + Ok(mut output) => { + let stream = stream::unfold(output.stream, |mut stream| async move { + match stream.recv().await { + Ok(Some(output)) => Some((Ok(Some(output)), stream)), + Ok(None) => Some((Ok(None), stream)), + Err(e) => Some(( + Err(BedrockError::ClientError(anyhow!( + "Failed to receive response from Bedrock" + ))), + stream, + )), + } + }) + .boxed(); + + Ok(stream) + } Err(e) => Err(BedrockError::ClientError(anyhow!(e))), } } diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 28950a86fa62b..d0a5ac83cca00 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -9,9 +9,7 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri}; use std::fmt; use aws_smithy_runtime_api::client::http::{HttpConnector, HttpConnectorFuture}; -use aws_smithy_runtime_api::client::orchestrator::{ - HttpRequest as AwsRequest, HttpResponse as AwsResponse, -}; +use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsRequest, HttpRequest, HttpResponse as AwsResponse}; use futures::future::BoxFuture; use http::request::Builder; @@ -333,20 +331,27 @@ impl HttpClient for BlockedHttpClient { } } -impl Debug for HttpClientWithUrl { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - todo!() - } +#[derive(Debug)] +pub struct AwsHttpClient { + [] + client: HttpClientWithProxy } -impl HttpConnector for HttpClientWithUrl { - fn call(&self, request: HttpRequest) -> HttpConnectorFuture { +impl HttpConnector for AwsHttpClient { + fn call(&self, request: AwsRequest) -> HttpConnectorFuture { + let request = convert_aws_request(request).unwrap(); + let response = self.client.send(request); + + let response = response.map(|response| { + response.map(|response| convert_to_aws_response(response).unwrap()) + }); + } } // Helper to convert AWS SDK request to your HttpClient request -fn convert_aws_request(aws_request: AwsRequest) -> Result, anyhow::Error> { +fn convert_aws_request(aws_request: AwsRequest) -> Result, anyhow::Error> { let owned_request = match aws_request.try_clone() { Some(req) => req, None => { @@ -374,12 +379,12 @@ fn convert_aws_request(aws_request: AwsRequest) -> Result, + response: Response, ) -> Result { let (parts, body) = response.into_parts(); let mut aws_response = - AwsResponse::new(AwsStatusCode::from(response.status()), response.body()); + AwsResponse::new(AwsStatusCode::from(parts.status), body); // Copy headers for (name, value) in parts.headers { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 0792ade2a17ee..155132c1bd04b 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -270,8 +270,9 @@ impl BedrockModel { Result>>, > { async move { - let request = bedrock::stream_completion(&self.runtime_client, request); - request.await.context("failed to perform stream completion") + let response = bedrock::stream_completion(&self.runtime_client, request); + + let unwrappedResponse: ConverseStreamOutput = response.await; } .boxed() } From 0d2bb3ffe2298758d7fe75cbe16473671d6bcbcb Mon Sep 17 00:00:00 2001 From: Shardul Vaidya Date: Fri, 6 Dec 2024 13:30:12 -0500 Subject: [PATCH 13/13] Added Nova Models --- crates/bedrock/src/models.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index fd004f2c8efab..509c845cfad91 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -14,6 +14,10 @@ pub enum Model { Claude3Sonnet, #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] Claude3_5Haiku, + // Amazon Nova Models + AmazonNovaLite, + AmazonNovaMicro, + AmazonNovaPro, // AI21 models AI21J2GrandeInstruct, AI21J2JumboInstruct, @@ -80,6 +84,9 @@ impl Model { Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0", Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0", + Model::AmazonNovaLite => "amazon.nova-lite-v1:0", + Model::AmazonNovaMicro => "amazon.nova-micro-v1:0", + Model::AmazonNovaPro => "amazon.nova-pro-v1:0", Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct", Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct", Model::AI21J2Mid => "ai21.j2-mid", @@ -118,6 +125,9 @@ impl Model { Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Sonnet => "Claude 3 Sonnet", Self::Claude3_5Haiku => "Claude 3.5 Haiku", + Self::AmazonNovaLite => "Amazon Nova Lite", + Self::AmazonNovaMicro => "Amazon Nova Micro", + Self::AmazonNovaPro => "Amazon Nova Pro", Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct", Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct", Self::AI21J2Mid => "AI21 Jurassic2 Mid",