diff --git a/app/buck2_execute_impl/BUCK b/app/buck2_execute_impl/BUCK index abeaf62ba365..45cbf9def8a0 100644 --- a/app/buck2_execute_impl/BUCK +++ b/app/buck2_execute_impl/BUCK @@ -27,6 +27,7 @@ rust_library( "fbsource//third-party/rust:assert_matches", ], deps = [ + "fbsource//third-party/rust:anyhow", "fbsource//third-party/rust:async-condvar-fair", "fbsource//third-party/rust:async-trait", "fbsource//third-party/rust:chrono", diff --git a/app/buck2_execute_impl/src/executors/worker.rs b/app/buck2_execute_impl/src/executors/worker.rs index d19bac50ea3e..046382c4f251 100644 --- a/app/buck2_execute_impl/src/executors/worker.rs +++ b/app/buck2_execute_impl/src/executors/worker.rs @@ -9,6 +9,8 @@ use std::collections::HashMap; use std::ffi::OsString; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; @@ -33,17 +35,25 @@ use buck2_execute::execute::result::CommandExecutionResult; use buck2_forkserver::client::ForkserverClient; use buck2_forkserver::run::GatherOutputStatus; use buck2_worker_proto::execute_command::EnvironmentEntry; -use buck2_worker_proto::worker_client::WorkerClient; +use buck2_worker_proto::worker_client; +use buck2_worker_proto::worker_streaming_client; use buck2_worker_proto::ExecuteCommand; +use buck2_worker_proto::ExecuteCommandStream; use buck2_worker_proto::ExecuteResponse; +use buck2_worker_proto::ExecuteResponseStream; +use dashmap::DashMap; +use dupe::Dupe; use futures::future::BoxFuture; use futures::future::Shared; use futures::FutureExt; use host_sharing::HostSharingBroker; use host_sharing::HostSharingStrategy; use indexmap::IndexMap; +use tokio::sync::mpsc::UnboundedSender; use tokio::task::JoinHandle; +use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::transport::Channel; +use tonic::Status; const MAX_MESSAGE_SIZE_BYTES: usize = 8 * 1024 * 1024; // 8MB @@ -308,9 +318,14 @@ async fn spawn_worker( }); tracing::info!("Connected to socket for spawned worker: {}", socket_path); - let client = WorkerClient::new(channel) - .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES) - .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES); + let client = if worker_spec.streaming { + WorkerClient::stream(channel) + .await + .map_err(|e| WorkerInitError::SpawnFailed(e.to_string()))? + } else { + WorkerClient::single(channel) + }; + Ok(WorkerHandle::new( client, child_exited_observer, @@ -394,8 +409,97 @@ impl WorkerPool { } } +#[derive(Clone)] +enum WorkerClient { + Single(worker_client::WorkerClient), + Stream { + ids: Arc, + stream: UnboundedSender, + waiters: Arc>>, + }, +} + +impl WorkerClient { + fn single(channel: Channel) -> Self { + Self::Single( + worker_client::WorkerClient::new(channel) + .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES) + .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES), + ) + } + + async fn stream(channel: Channel) -> Result { + let mut client = worker_streaming_client::WorkerStreamingClient::new(channel) + .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES) + .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let stream = client + .execute_stream(tonic::Request::new(UnboundedReceiverStream::new(rx))) + .await?; + let waiters: Arc>> = + Default::default(); + { + let waiters = waiters.dupe(); + tokio::spawn(async move { + use futures::StreamExt; + + let mut stream = stream.into_inner(); + while let Some(response) = stream.next().await { + let response = response.unwrap(); + match waiters.remove(&response.id) { + Some(waiter) => { + let id = response.id; + if waiter.1.send(response).is_err() { + tracing::warn!( + id = id, + "Error passing streaming worker response to waiter" + ); + } + } + None => { + tracing::warn!( + id = response.id, + "Missing waiter for streaming worker response", + ); + } + }; + } + }); + } + Ok(Self::Stream { + ids: Default::default(), + stream: tx, + waiters, + }) + } + + async fn execute(&mut self, request: ExecuteCommand) -> anyhow::Result { + match self { + Self::Single(client) => Ok(client + .execute(request) + .await + .map(|response| response.into_inner())?), + Self::Stream { + ids, + stream, + waiters, + } => { + let id = ids.fetch_add(1, Ordering::Acquire); + let req = ExecuteCommandStream { + request: Some(request), + id, + }; + let (tx, rx) = tokio::sync::oneshot::channel(); + waiters.insert(id, tx); + stream.send(req)?; + Ok(rx.await.map(|response| response.response.unwrap())?) + } + } + } +} + pub struct WorkerHandle { - client: WorkerClient, + client: WorkerClient, child_exited_observer: Arc, stdout_path: AbsNormPathBuf, stderr_path: AbsNormPathBuf, @@ -404,7 +508,7 @@ pub struct WorkerHandle { impl WorkerHandle { fn new( - client: WorkerClient, + client: WorkerClient, child_exited_observer: Arc, stdout_path: AbsNormPathBuf, stderr_path: AbsNormPathBuf, @@ -461,8 +565,7 @@ impl WorkerHandle { tokio::select! { response = client.execute(request) => { match response { - Ok(response) => { - let exec_response: ExecuteResponse = response.into_inner(); + Ok(exec_response) => { tracing::info!("Worker response:\n{:?}\n", exec_response); if let Some(timeout) = exec_response.timed_out_after_s { ( diff --git a/app/buck2_worker_proto/worker.proto b/app/buck2_worker_proto/worker.proto index d87eab214143..b24af68dfd3b 100644 --- a/app/buck2_worker_proto/worker.proto +++ b/app/buck2_worker_proto/worker.proto @@ -47,3 +47,20 @@ service Worker { rpc Exec(stream ExecuteEvent) returns (ExecuteResponse) {}; } + +message ExecuteCommandStream { + ExecuteCommand request = 1; + uint64 id = 2; +} + +message ExecuteResponseStream { + ExecuteResponse response = 1; + uint64 id = 2; +} + +// This is its own interface because it significantly complicates worker implementation. +// Most workers do not need streaming, nor do they benefit from it. +service WorkerStreaming { + rpc ExecuteStream(stream ExecuteCommandStream) + returns (stream ExecuteResponseStream) {}; +} diff --git a/tests/e2e/build/test_worker.py b/tests/e2e/build/test_worker.py index bcc640bb30a0..4e996fe9f805 100644 --- a/tests/e2e/build/test_worker.py +++ b/tests/e2e/build/test_worker.py @@ -110,6 +110,14 @@ def target_name(identity: str) -> str: ] assert what_ran_matching == expected, what_ran + # Streaming execution demo + res = await buck.build(*worker_args, package + ":gen_worker_run_out_streaming") + output = res.get_build_report().output_for_target( + package + ":gen_worker_run_out_streaming" + ) + assert output.read_text() == "hello worker" + assert len(await read_what_ran_for_executor(buck, "Worker")) == 1 + # TODO(ctolliday) re-enable once cancellation is in place # assert_executed( # await read_what_ran(buck),