From a4b818ec6bc9d438dbd548854bdf2638d2824332 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Tue, 29 Oct 2024 14:43:34 -0600 Subject: [PATCH] Better handle interrupted connections for shared SSH (#19925) Co-Authored-By: Mikayla --- crates/remote/src/ssh_session.rs | 34 ++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 16b76628710a16..ecee6223bff47d 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1288,6 +1288,7 @@ impl SshRemoteConnection { ) -> Result { use futures::AsyncWriteExt as _; use futures::{io::BufReader, AsyncBufReadExt as _}; + use smol::net::unix::UnixStream; use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener}; use util::ResultExt as _; @@ -1304,6 +1305,9 @@ impl SshRemoteConnection { let listener = UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?; + let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::(); + let mut kill_tx = Some(askpass_kill_master_tx); + let askpass_task = cx.spawn({ let delegate = delegate.clone(); |mut cx| async move { @@ -1327,6 +1331,11 @@ impl SshRemoteConnection { .log_err() { stream.write_all(password.as_bytes()).await.log_err(); + } else { + if let Some(kill_tx) = kill_tx.take() { + kill_tx.send(stream).log_err(); + break; + } } } } @@ -1347,6 +1356,7 @@ impl SshRemoteConnection { // the connection and keep it open, allowing other ssh commands to reuse it // via a control socket. let socket_path = temp_dir.path().join("ssh.sock"); + let mut master_process = process::Command::new("ssh") .stdin(Stdio::null()) .stdout(Stdio::piped()) @@ -1369,20 +1379,28 @@ impl SshRemoteConnection { // Wait for this ssh process to close its stdout, indicating that authentication // has completed. - let stdout = master_process.stdout.as_mut().unwrap(); + let mut stdout = master_process.stdout.take().unwrap(); let mut output = Vec::new(); let connection_timeout = Duration::from_secs(10); let result = select_biased! { _ = askpass_opened_rx.fuse() => { - // If the askpass script has opened, that means the user is typing - // their password, in which case we don't want to timeout anymore, - // since we know a connection has been established. - stdout.read_to_end(&mut output).await?; - Ok(()) + select_biased! { + stream = askpass_kill_master_rx.fuse() => { + master_process.kill().ok(); + drop(stream); + Err(anyhow!("SSH connection canceled")) + } + // If the askpass script has opened, that means the user is typing + // their password, in which case we don't want to timeout anymore, + // since we know a connection has been established. + result = stdout.read_to_end(&mut output).fuse() => { + result?; + Ok(()) + } + } } - result = stdout.read_to_end(&mut output).fuse() => { - result?; + _ = stdout.read_to_end(&mut output).fuse() => { Ok(()) } _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {