Skip to content

Commit

Permalink
Use tokio JoinSet to correctly fail if spawned task fails
Browse files Browse the repository at this point in the history
  • Loading branch information
madmikeross committed Dec 17, 2023
1 parent e52fbba commit 5a45fd3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ uuid = "1.6.0"
thiserror = "1.0.50"
warp = "0.3.6"
chrono = "0.4.31"
openssl = { version = "0.10", features = ["vendored"] }
openssl = { version = "0.10", features = ["vendored"] }
63 changes: 41 additions & 22 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::convert::Infallible;
use std::sync::Arc;

use neo4rs::Graph;
use neo4rs::{Error, Graph};
use reqwest::Client;
use thiserror::Error;
use tokio::task::JoinError;
use tokio::task::{JoinError, JoinSet};
use warp::hyper::StatusCode;
use warp::reject::Reject;
use warp::reply::json;
use warp::{reply, Filter, Rejection, Reply};

Expand Down Expand Up @@ -139,7 +140,7 @@ async fn systems_refresh_handler(
client: Client,
graph: Arc<Graph>,
) -> Result<impl Reply, Rejection> {
pull_all_systems(client, graph).await.unwrap();
pull_all_systems(client, graph).await?;
Ok(reply())
}

Expand Down Expand Up @@ -186,18 +187,24 @@ async fn pull_all_stargates(client: Client, graph: Arc<Graph>) -> Result<(), Rep
async fn pull_all_systems(client: Client, graph: Arc<Graph>) -> Result<(), ReplicationError> {
let system_ids = get_system_ids(&client).await.unwrap();
println!("Received {} system ids from ESI", system_ids.len());
let system_pulls: Vec<_> = system_ids
.iter()
.map(|&system_id| {
println!("Spawning task to pull {} system if its missing", system_id);
tokio::spawn(pull_system_if_missing(
client.clone(),
graph.clone(),
system_id,
))
})
.collect();
futures::future::try_join_all(system_pulls).await?;

let mut set = JoinSet::new();

system_ids.iter().for_each(|&system_id| {
println!("Spawning task to pull {} system if its missing", system_id);
set.spawn(pull_system_if_missing(
client.clone(),
graph.clone(),
system_id,
));
});

while let Some(res) = set.join_next().await {
if let Err(e) = res.unwrap() {
return Err(e);
}
}

Ok(())
}

Expand All @@ -207,14 +214,24 @@ async fn pull_system_if_missing(
system_id: i64,
) -> Result<(), ReplicationError> {
println!("Checking if system_id {} exists in the database", system_id);
if !system_id_exists(graph.clone(), system_id).await? {
println!(
"System {} does not already exist in the database",
system_id
);
pull_system(client, graph.clone(), system_id).await?
let result = system_id_exists(graph.clone(), system_id).await;

match result {
Ok(exists) => {
if exists {
println!(
"System {} does not already exist in the database",
system_id
);
pull_system(client, graph.clone(), system_id).await?;
}
Ok(())
}
Err(_) => {
println!("Error checking if system_id {} exists", system_id);
Err(TargetError(Error::ConnectionError))
}
}
Ok(())
}

impl From<SystemEsiResponse> for System {
Expand Down Expand Up @@ -252,6 +269,8 @@ enum ReplicationError {
TargetError(#[from] neo4rs::Error),
}

impl Reject for ReplicationError {}

async fn pull_system(
client: Client,
graph: Arc<Graph>,
Expand Down

0 comments on commit 5a45fd3

Please sign in to comment.