Streaming worker tools
Summary: Worker tools with huge numbers (millions) of actions queued up will have millions of async tasks consuming memory; this can reach the order of gigabytes. To avoid this, workers whose execution model is inherently async can opt in to streaming execution, which keeps a single async stack for all actions. See the next diff for an example. I made this opt-in because writing workers in this way is complicated and usually not desirable.
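
To make the mechanism concrete, here is a minimal sketch of the multiplexing pattern the streaming client below implements. The Req/Resp types, field names, and channel wiring are illustrative placeholders, not the buck2 proto types: each pending action registers a oneshot sender keyed by a monotonically increasing id and writes its request to one shared stream, and a single background task drains the response stream and routes each response back by id.

    // Illustrative sketch only: the request-multiplexing shape described above,
    // with placeholder Req/Resp types instead of the real proto messages.
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    use dashmap::DashMap;
    use tokio::sync::{mpsc, oneshot};

    struct Req { id: u64 }
    struct Resp { id: u64 }

    struct StreamingClient {
        ids: AtomicU64,
        to_worker: mpsc::UnboundedSender<Req>,
        waiters: Arc<DashMap<u64, oneshot::Sender<Resp>>>,
    }

    impl StreamingClient {
        // Register a waiter, write the request to the shared stream, then await
        // the routed response. A queued action costs one map entry plus a oneshot
        // channel rather than a whole suspended RPC future.
        async fn execute(&self) -> Option<Resp> {
            let id = self.ids.fetch_add(1, Ordering::Relaxed);
            let (tx, rx) = oneshot::channel();
            self.waiters.insert(id, tx);
            self.to_worker.send(Req { id }).ok()?;
            rx.await.ok()
        }
    }

    // The single task that drains responses from the worker and wakes the right
    // waiter by id.
    async fn route_responses(
        mut from_worker: mpsc::UnboundedReceiver<Resp>,
        waiters: Arc<DashMap<u64, oneshot::Sender<Resp>>>,
    ) {
        while let Some(resp) = from_worker.recv().await {
            if let Some((_, waiter)) = waiters.remove(&resp.id) {
                let _ = waiter.send(resp);
            }
        }
    }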

Reviewed By: christolliday

Differential Revision: D67939003

fbshipit-source-id: eb179718221370c0d8e4ebfbf2d0e1b38a38e8a7
Ron Mordechai authored and facebook-github-bot committed Jan 15, 2025
1 parent 05a4a94 commit a47d304
Showing 4 changed files with 137 additions and 8 deletions.
1 change: 1 addition & 0 deletions app/buck2_execute_impl/BUCK
@@ -27,6 +27,7 @@ rust_library(
"fbsource//third-party/rust:assert_matches",
],
deps = [
"fbsource//third-party/rust:anyhow",
"fbsource//third-party/rust:async-condvar-fair",
"fbsource//third-party/rust:async-trait",
"fbsource//third-party/rust:chrono",
119 changes: 111 additions & 8 deletions app/buck2_execute_impl/src/executors/worker.rs
@@ -9,6 +9,8 @@

use std::collections::HashMap;
use std::ffi::OsString;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;

@@ -33,17 +35,25 @@ use buck2_execute::execute::result::CommandExecutionResult;
use buck2_forkserver::client::ForkserverClient;
use buck2_forkserver::run::GatherOutputStatus;
use buck2_worker_proto::execute_command::EnvironmentEntry;
use buck2_worker_proto::worker_client::WorkerClient;
use buck2_worker_proto::worker_client;
use buck2_worker_proto::worker_streaming_client;
use buck2_worker_proto::ExecuteCommand;
use buck2_worker_proto::ExecuteCommandStream;
use buck2_worker_proto::ExecuteResponse;
use buck2_worker_proto::ExecuteResponseStream;
use dashmap::DashMap;
use dupe::Dupe;
use futures::future::BoxFuture;
use futures::future::Shared;
use futures::FutureExt;
use host_sharing::HostSharingBroker;
use host_sharing::HostSharingStrategy;
use indexmap::IndexMap;
use tokio::sync::mpsc::UnboundedSender;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::transport::Channel;
use tonic::Status;

const MAX_MESSAGE_SIZE_BYTES: usize = 8 * 1024 * 1024; // 8MB

@@ -308,9 +318,14 @@ async fn spawn_worker(
});

tracing::info!("Connected to socket for spawned worker: {}", socket_path);
let client = WorkerClient::new(channel)
.max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES)
.max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES);
let client = if worker_spec.streaming {
WorkerClient::stream(channel)
.await
.map_err(|e| WorkerInitError::SpawnFailed(e.to_string()))?
} else {
WorkerClient::single(channel)
};

Ok(WorkerHandle::new(
client,
child_exited_observer,
@@ -394,8 +409,97 @@ impl WorkerPool {
}
}

#[derive(Clone)]
enum WorkerClient {
    Single(worker_client::WorkerClient<Channel>),
    Stream {
        ids: Arc<AtomicU64>,
        stream: UnboundedSender<ExecuteCommandStream>,
        waiters: Arc<DashMap<u64, tokio::sync::oneshot::Sender<ExecuteResponseStream>>>,
    },
}

impl WorkerClient {
    fn single(channel: Channel) -> Self {
        Self::Single(
            worker_client::WorkerClient::new(channel)
                .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES)
                .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES),
        )
    }

    async fn stream(channel: Channel) -> Result<Self, Status> {
        let mut client = worker_streaming_client::WorkerStreamingClient::new(channel)
            .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES)
            .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES);
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let stream = client
            .execute_stream(tonic::Request::new(UnboundedReceiverStream::new(rx)))
            .await?;
        let waiters: Arc<DashMap<u64, tokio::sync::oneshot::Sender<ExecuteResponseStream>>> =
            Default::default();
        {
            let waiters = waiters.dupe();
            tokio::spawn(async move {
                use futures::StreamExt;

                let mut stream = stream.into_inner();
                while let Some(response) = stream.next().await {
                    let response = response.unwrap();
                    match waiters.remove(&response.id) {
                        Some(waiter) => {
                            let id = response.id;
                            if waiter.1.send(response).is_err() {
                                tracing::warn!(
                                    id = id,
                                    "Error passing streaming worker response to waiter"
                                );
                            }
                        }
                        None => {
                            tracing::warn!(
                                id = response.id,
                                "Missing waiter for streaming worker response",
                            );
                        }
                    };
                }
            });
        }
        Ok(Self::Stream {
            ids: Default::default(),
            stream: tx,
            waiters,
        })
    }

    async fn execute(&mut self, request: ExecuteCommand) -> anyhow::Result<ExecuteResponse> {
        match self {
            Self::Single(client) => Ok(client
                .execute(request)
                .await
                .map(|response| response.into_inner())?),
            Self::Stream {
                ids,
                stream,
                waiters,
            } => {
                let id = ids.fetch_add(1, Ordering::Acquire);
                let req = ExecuteCommandStream {
                    request: Some(request),
                    id,
                };
                let (tx, rx) = tokio::sync::oneshot::channel();
                waiters.insert(id, tx);
                stream.send(req)?;
                Ok(rx.await.map(|response| response.response.unwrap())?)
            }
        }
    }
}

pub struct WorkerHandle {
client: WorkerClient<Channel>,
client: WorkerClient,
child_exited_observer: Arc<dyn LivelinessObserver>,
stdout_path: AbsNormPathBuf,
stderr_path: AbsNormPathBuf,
@@ -404,7 +508,7 @@ pub struct WorkerHandle {

impl WorkerHandle {
fn new(
client: WorkerClient<Channel>,
client: WorkerClient,
child_exited_observer: Arc<dyn LivelinessObserver>,
stdout_path: AbsNormPathBuf,
stderr_path: AbsNormPathBuf,
@@ -461,8 +565,7 @@ impl WorkerHandle {
tokio::select! {
response = client.execute(request) => {
match response {
Ok(response) => {
let exec_response: ExecuteResponse = response.into_inner();
Ok(exec_response) => {
tracing::info!("Worker response:\n{:?}\n", exec_response);
if let Some(timeout) = exec_response.timed_out_after_s {
(
17 changes: 17 additions & 0 deletions app/buck2_worker_proto/worker.proto
@@ -47,3 +47,20 @@ service Worker {

rpc Exec(stream ExecuteEvent) returns (ExecuteResponse) {};
}

message ExecuteCommandStream {
ExecuteCommand request = 1;
uint64 id = 2;
}

message ExecuteResponseStream {
ExecuteResponse response = 1;
uint64 id = 2;
}

// This is its own interface because it significantly complicates worker implementation.
// Most workers do not need streaming, nor do they benefit from it.
service WorkerStreaming {
rpc ExecuteStream(stream ExecuteCommandStream)
returns (stream ExecuteResponseStream) {};
}
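
For illustration of what opting in to this interface asks of a worker, a hypothetical worker-side handler might look roughly like the sketch below. The worker_streaming_server module path follows tonic's usual codegen layout and the empty ExecuteResponse body is a placeholder; neither is code from this commit.

    // Hypothetical worker-side sketch, not part of this commit. A worker that opts
    // into streaming drains one long-lived command stream and tags every response
    // with the id of the command it answers.
    use buck2_worker_proto::worker_streaming_server::WorkerStreaming; // assumed tonic codegen path
    use buck2_worker_proto::{ExecuteCommandStream, ExecuteResponse, ExecuteResponseStream};
    use futures::StreamExt;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    struct ExampleWorker;

    #[tonic::async_trait]
    impl WorkerStreaming for ExampleWorker {
        type ExecuteStreamStream = UnboundedReceiverStream<Result<ExecuteResponseStream, Status>>;

        async fn execute_stream(
            &self,
            request: Request<Streaming<ExecuteCommandStream>>,
        ) -> Result<Response<Self::ExecuteStreamStream>, Status> {
            let mut commands = request.into_inner();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            // One task owns the whole command stream; scheduling of individual
            // commands is up to the worker, but every response must echo the id
            // of its command so the client can route it to the right waiter.
            tokio::spawn(async move {
                while let Some(Ok(command)) = commands.next().await {
                    let response = ExecuteResponse::default(); // placeholder: run the action and fill in the result
                    let _ = tx.send(Ok(ExecuteResponseStream {
                        response: Some(response),
                        id: command.id,
                    }));
                }
            });
            Ok(Response::new(UnboundedReceiverStream::new(rx)))
        }
    }
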
8 changes: 8 additions & 0 deletions tests/e2e/build/test_worker.py
@@ -110,6 +110,14 @@ def target_name(identity: str) -> str:
]
assert what_ran_matching == expected, what_ran

# Streaming execution demo
res = await buck.build(*worker_args, package + ":gen_worker_run_out_streaming")
output = res.get_build_report().output_for_target(
package + ":gen_worker_run_out_streaming"
)
assert output.read_text() == "hello worker"
assert len(await read_what_ran_for_executor(buck, "Worker")) == 1

# TODO(ctolliday) re-enable once cancellation is in place
# assert_executed(
# await read_what_ran(buck),