Skip to content

Commit

Permalink
Handle statistics threads in AppState
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Nov 22, 2024
1 parent 0ad5105 commit 1a77080
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 174 deletions.
8 changes: 4 additions & 4 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 35 additions & 5 deletions src-tauri/src/appstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::{
sync::Arc,
};

use tauri::AppHandle;
use tauri::{
async_runtime::{spawn, JoinHandle},
AppHandle,
};
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tonic::transport::Channel;
Expand All @@ -18,7 +21,7 @@ use crate::{
service::{
proto::desktop_daemon_service_client::DesktopDaemonServiceClient, utils::setup_client,
},
utils::disconnect_interface,
utils::{disconnect_interface, stats_handler},
ConnectionType,
};

Expand All @@ -28,6 +31,7 @@ pub struct AppState {
pub client: DesktopDaemonServiceClient<Channel>,
pub log_watchers: Arc<std::sync::Mutex<HashMap<String, CancellationToken>>>,
pub app_config: Arc<std::sync::Mutex<AppConfig>>,
stat_threads: std::sync::Mutex<HashMap<Id, JoinHandle<()>>>, // location ID is the key
}

impl AppState {
Expand All @@ -40,6 +44,7 @@ impl AppState {
client,
log_watchers: Arc::new(std::sync::Mutex::new(HashMap::new())),
app_config: Arc::new(std::sync::Mutex::new(AppConfig::new(app_handle))),
stat_threads: std::sync::Mutex::new(HashMap::new()),
}
}

Expand All @@ -57,11 +62,29 @@ impl AppState {
interface_name: S,
connection_type: ConnectionType,
) {
let connection = ActiveConnection::new(location_id, interface_name.into(), connection_type);
let ifname = interface_name.into();
let connection = ActiveConnection::new(location_id, ifname.clone(), connection_type);
trace!("Adding active connection for location ID: {location_id}");
let mut connections = self.active_connections.lock().await;
connections.push(connection);
trace!("Current active connections: {connections:?}");

debug!("Spawning thread for network statistics for location ID {location_id}");
let handle = spawn(stats_handler(
self.get_pool(),
ifname,
connection_type,
self.client.clone(),
));
if let Some(old_handle) = self
.stat_threads
.lock()
.unwrap()
.insert(location_id, handle)
{
warn!("Something went wrong: old network statistics thread still exists");
old_handle.abort();
}
}

/// Try to remove a connection from the list of active connections.
Expand All @@ -72,6 +95,13 @@ impl AppState {
connection_type: &ConnectionType,
) -> Option<ActiveConnection> {
trace!("Removing active connection for location ID: {location_id}");

// Stop statistics thread
if let Some(handle) = self.stat_threads.lock().unwrap().get(&location_id) {
debug!("Stopping network statistics thread for {location_id}");
handle.abort();
}

let mut connections = self.active_connections.lock().await;

if let Some(index) = connections.iter().position(|conn| {
Expand All @@ -89,14 +119,14 @@ impl AppState {

pub(crate) async fn get_connection_id_by_type(
&self,
connection_type: &ConnectionType,
connection_type: ConnectionType,
) -> Vec<Id> {
let active_connections = self.active_connections.lock().await;

let connection_ids = active_connections
.iter()
.filter_map(|con| {
if con.connection_type.eq(connection_type) {
if con.connection_type == connection_type {
Some(con.location_id)
} else {
None
Expand Down
16 changes: 7 additions & 9 deletions src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use sqlx::{Sqlite, Transaction};
use struct_patch::Patch;
use tauri::{AppHandle, Manager, State};

static UPDATE_URL: &str = "https://pkgs.defguard.net/api/update/check";

use crate::{
app_config::{AppConfig, AppConfigPatch},
appstate::AppState,
Expand Down Expand Up @@ -306,7 +308,7 @@ pub async fn all_instances(app_state: State<'_, AppState>) -> Result<Vec<Instanc
trace!("Instances found: {instances:#?}");
let mut instance_info = Vec::new();
let connection_ids = app_state
.get_connection_id_by_type(&ConnectionType::Location)
.get_connection_id_by_type(ConnectionType::Location)
.await;
for instance in instances {
let locations = Location::find_by_instance_id(&app_state.get_pool(), instance.id).await?;
Expand Down Expand Up @@ -376,8 +378,8 @@ pub async fn all_locations(
"Found {} locations for instance {instance} to return information about.",
locations.len()
);
let active_locations_ids: Vec<i64> = app_state
.get_connection_id_by_type(&ConnectionType::Location)
let active_locations_ids = app_state
.get_connection_id_by_type(ConnectionType::Location)
.await;
let mut location_info = Vec::new();
for location in locations {
Expand Down Expand Up @@ -962,7 +964,7 @@ pub async fn all_tunnels(app_state: State<'_, AppState>) -> Result<Vec<TunnelInf
trace!("Tunnels found: {tunnels:#?}");
let mut tunnel_info = Vec::new();
let active_tunnel_ids = app_state
.get_connection_id_by_type(&ConnectionType::Tunnel)
.get_connection_id_by_type(ConnectionType::Tunnel)
.await;

for tunnel in tunnels {
Expand Down Expand Up @@ -1082,11 +1084,7 @@ pub async fn get_latest_app_version(handle: AppHandle) -> Result<AppVersionInfo,
debug!("Fetching latest application version, client metadata: current version: {app_version} and operating system: {operating_system}");

let client = reqwest::Client::new();
let res = client
.post("https://pkgs.defguard.net/api/update/check")
.json(&request_data)
.send()
.await;
let res = client.post(UPDATE_URL).json(&request_data).send().await;

if let Ok(response) = res {
let response_json = response.json::<AppVersionInfo>().await;
Expand Down
2 changes: 1 addition & 1 deletion src-tauri/src/enterprise/periodic/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
};

const INTERVAL_SECONDS: Duration = Duration::from_secs(30);
const POLLING_ENDPOINT: &str = "/api/v1/poll";
static POLLING_ENDPOINT: &str = "/api/v1/poll";

/// Periodically retrieves and updates configuration for all [`Instance`]s.
/// Updates are only performed if no connections are established to the [`Instance`],
Expand Down
6 changes: 3 additions & 3 deletions src-tauri/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,16 @@ impl DesktopDaemonService for DaemonService {
Status::new(Code::Internal, msg)
})?;

if !dns.is_empty() {
if dns.is_empty() {
debug!("No DNS configuration provided for interface {ifname}, skipping DNS configuration");
} else {
debug!("The following DNS servers will be set: {dns:?}, search domains: {search_domains:?}");
wgapi.configure_dns(&dns, &search_domains).map_err(|err| {
let msg =
format!("Failed to configure DNS for WireGuard interface {ifname}: {err}");
error!("{msg}");
Status::new(Code::Internal, msg)
})?;
} else {
debug!("No DNS configuration provided for interface {ifname}, skipping DNS configuration");
}
}

Expand Down
4 changes: 2 additions & 2 deletions src-tauri/src/tray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ async fn handle_location_tray_menu(id: String, handle: &AppHandle) {
Ok(location_id) => {
match Location::find_by_id(&handle.state::<AppState>().get_pool(), location_id).await {
Ok(Some(location)) => {
let active_locations_ids: Vec<i64> = handle
let active_locations_ids = handle
.state::<AppState>()
.get_connection_id_by_type(&ConnectionType::Location)
.get_connection_id_by_type(ConnectionType::Location)
.await;

if active_locations_ids.contains(&location_id) {
Expand Down
Loading

0 comments on commit 1a77080

Please sign in to comment.