diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 4e6b6ef227c32..b4ac2731e0bdc 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -4,9 +4,9 @@ use anyhow::Result; use assistant_tool::ToolWorkingSet; use client::zed_urls; use gpui::{ - prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, - FocusableView, FontWeight, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, - WindowContext, + list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter, + FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription, + Task, View, ViewContext, WeakView, WindowContext, }; use language_model::{LanguageModelRegistry, Role}; use language_model_selector::LanguageModelSelector; @@ -15,7 +15,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::Workspace; use crate::message_editor::MessageEditor; -use crate::thread::{Message, Thread, ThreadError, ThreadEvent}; +use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent}; use crate::thread_store::ThreadStore; use crate::{NewThread, ToggleFocus, ToggleModelSelector}; @@ -35,6 +35,8 @@ pub struct AssistantPanel { #[allow(unused)] thread_store: Model, thread: Model, + thread_messages: Vec, + thread_list_state: ListState, message_editor: View, tools: Arc, last_error: Option, @@ -77,6 +79,14 @@ impl AssistantPanel { workspace: workspace.weak_handle(), thread_store, thread: thread.clone(), + thread_messages: Vec::new(), + thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { + let this = cx.view().downgrade(); + move |ix, cx: &mut WindowContext| { + this.update(cx, |this, cx| this.render_message(ix, cx)) + .unwrap() + } + }), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), tools, last_error: None, @@ -110,6 +120,12 @@ impl AssistantPanel { self.last_error = Some(error.clone()); } ThreadEvent::StreamedCompletion => {} + ThreadEvent::MessageAdded(message_id) => { + let old_len = self.thread_messages.len(); + self.thread_messages.push(*message_id); + self.thread_list_state.splice(old_len..old_len, 1); + cx.notify(); + } ThreadEvent::UsePendingTools => { let pending_tool_uses = self .thread @@ -301,31 +317,42 @@ impl AssistantPanel { ) } - fn render_message(&self, message: Message, cx: &mut ViewContext) -> impl IntoElement { + fn render_message(&self, ix: usize, cx: &mut ViewContext) -> AnyElement { + let message_id = self.thread_messages[ix]; + let Some(message) = self.thread.read(cx).message(message_id) else { + return Empty.into_any(); + }; + let (role_icon, role_name) = match message.role { Role::User => (IconName::Person, "You"), Role::Assistant => (IconName::ZedAssistant, "Assistant"), Role::System => (IconName::Settings, "System"), }; - v_flex() - .border_1() - .border_color(cx.theme().colors().border_variant) - .rounded_md() + div() + .id(("message-container", ix)) + .p_2() .child( - h_flex() - .justify_between() - .p_1p5() - .border_b_1() + v_flex() + .border_1() .border_color(cx.theme().colors().border_variant) + .rounded_md() .child( h_flex() - .gap_2() - .child(Icon::new(role_icon).size(IconSize::Small)) - .child(Label::new(role_name).size(LabelSize::Small)), - ), + .justify_between() + .p_1p5() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child( + h_flex() + .gap_2() + .child(Icon::new(role_icon).size(IconSize::Small)) + .child(Label::new(role_name).size(LabelSize::Small)), + ), + ) + .child(v_flex().p_1p5().child(Label::new(message.text.clone()))), ) - .child(v_flex().p_1p5().child(Label::new(message.text.clone()))) + .into_any() } fn render_last_error(&self, cx: &mut ViewContext) -> Option { @@ -477,8 +504,6 @@ impl AssistantPanel { impl Render for AssistantPanel { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let messages = self.thread.read(cx).messages().cloned().collect::>(); - v_flex() .key_context("AssistantPanel2") .justify_between() @@ -487,20 +512,7 @@ impl Render for AssistantPanel { this.new_thread(cx); })) .child(self.render_toolbar(cx)) - .child( - v_flex() - .id("message-list") - .gap_2() - .size_full() - .p_2() - .overflow_y_scroll() - .bg(cx.theme().colors().panel_background) - .children( - messages - .into_iter() - .map(|message| self.render_message(message, cx)), - ), - ) + .child(list(self.thread_list_state.clone()).flex_1()) .child( h_flex() .border_t_1() diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 7f789587c65b3..d1b1cf55e4657 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -56,7 +56,7 @@ impl MessageEditor { }); self.thread.update(cx, |thread, cx| { - thread.insert_user_message(user_message); + thread.insert_user_message(user_message, cx); let mut request = thread.to_completion_request(request_kind, cx); if self.use_tools { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index a5ab415a4d7e1..43868fffffcb9 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -63,8 +63,8 @@ impl Thread { } } - pub fn messages(&self) -> impl Iterator { - self.messages.iter() + pub fn message(&self, id: MessageId) -> Option<&Message> { + self.messages.iter().find(|message| message.id == id) } pub fn tools(&self) -> &Arc { @@ -75,12 +75,14 @@ impl Thread { self.pending_tool_uses_by_id.values().collect() } - pub fn insert_user_message(&mut self, text: impl Into) { + pub fn insert_user_message(&mut self, text: impl Into, cx: &mut ModelContext) { + let id = self.next_message_id.post_inc(); self.messages.push(Message { - id: self.next_message_id.post_inc(), + id, role: Role::User, text: text.into(), }); + cx.emit(ThreadEvent::MessageAdded(id)); } pub fn to_completion_request( @@ -150,11 +152,13 @@ impl Thread { thread.update(&mut cx, |thread, cx| { match event { LanguageModelCompletionEvent::StartMessage { .. } => { + let id = thread.next_message_id.post_inc(); thread.messages.push(Message { - id: thread.next_message_id.post_inc(), + id, role: Role::Assistant, text: String::new(), }); + cx.emit(ThreadEvent::MessageAdded(id)); } LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -316,6 +320,7 @@ pub enum ThreadError { pub enum ThreadEvent { ShowError(ThreadError), StreamedCompletion, + MessageAdded(MessageId), UsePendingTools, ToolFinished { #[allow(unused)]