Skip to content

Commit

Permalink
impl tray menu instance submenu (#289)
Browse files Browse the repository at this point in the history
* impl tray menu instance submenu

* refactor handle_tray_event

* Apply suggestions from code review

Co-authored-by: Adam <[email protected]>

* Apply suggestions from code review part 2

Co-authored-by: Adam <[email protected]>

* Apply suggestions from code review part 3

* add mfa trigger event

---------

Co-authored-by: Adam <[email protected]>
  • Loading branch information
cpprian and moubctez authored Sep 4, 2024
1 parent 8b00899 commit f2f7310
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 32 deletions.
10 changes: 5 additions & 5 deletions src-tauri/src/bin/defguard-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use defguard_client::{
},
database::{self, models::settings::Settings},
latest_app_version::fetch_latest_app_version_loop,
tray::{configure_tray_icon, create_tray_menu, handle_tray_event},
tray::{configure_tray_icon, handle_tray_event, reload_tray_menu},
utils::load_log_targets,
};
use std::{env, str::FromStr};
Expand Down Expand Up @@ -70,9 +70,6 @@ async fn main() {
debug!("Added binary dir {current_bin_dir:?} to PATH");
}

let tray_menu = create_tray_menu();
let system_tray = SystemTray::new().with_menu(tray_menu);

let log_level =
LevelFilter::from_str(&env::var("DEFGUARD_CLIENT_LOG_LEVEL").unwrap_or("info".into()))
.unwrap_or(LevelFilter::Info);
Expand Down Expand Up @@ -114,7 +111,7 @@ async fn main() {
}
_ => {}
})
.system_tray(system_tray)
.system_tray(SystemTray::new())
.on_system_tray_event(handle_tray_event)
.plugin(tauri_plugin_single_instance::init(|app, argv, cwd| {
let _ = app.emit_all("single-instance", Payload { args: argv, cwd });
Expand Down Expand Up @@ -163,6 +160,9 @@ async fn main() {

tauri::async_runtime::spawn(fetch_latest_app_version_loop(app_handle.clone()));

// load tray menu after database initialization to show all instance and locations
reload_tray_menu(&app_handle).await;

// Handle Ctrl-C
tauri::async_runtime::spawn(async move {
tokio::signal::ctrl_c()
Expand Down
6 changes: 4 additions & 2 deletions src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
error::Error,
proto::{DeviceConfig, DeviceConfigResponse},
service::{log_watcher::stop_log_watcher_task, proto::RemoveInterfaceRequest},
tray::configure_tray_icon,
tray::{configure_tray_icon, reload_tray_menu},
utils::{
disconnect_interface, get_location_interface_details, get_tunnel_interface_details,
handle_connection_for_location, handle_connection_for_tunnel,
Expand Down Expand Up @@ -39,7 +39,8 @@ pub async fn connect(
let state = handle.state::<AppState>();
if connection_type.eq(&ConnectionType::Location) {
if let Some(location) = Location::find_by_id(&state.get_pool(), location_id).await? {
handle_connection_for_location(&location, preshared_key, handle).await?;
handle_connection_for_location(&location, preshared_key, handle.clone()).await?;
reload_tray_menu(&handle).await;
} else {
error!("Location {location_id} not found");
return Err(Error::NotFound);
Expand Down Expand Up @@ -79,6 +80,7 @@ pub async fn disconnect(
)?;
stop_log_watcher_task(&handle, &interface_name)?;
info!("Disconnected from location with id: {location_id}");
reload_tray_menu(&handle).await;
Ok(())
} else {
error!("Error while disconnecting from location with id: {location_id} not found");
Expand Down
145 changes: 125 additions & 20 deletions src-tauri/src/tray.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,87 @@
use tauri::{
AppHandle, CustomMenuItem, Icon, Manager, State, SystemTrayEvent, SystemTrayMenu,
SystemTrayMenuItem,
SystemTrayMenuItem, SystemTraySubmenu,
};

use crate::{appstate::AppState, database::TrayIconTheme, error::Error};
use crate::{
appstate::AppState,
commands::{all_instances, all_locations, connect, disconnect},
database::{Location, TrayIconTheme},
error::Error,
ConnectionType,
};

static SUBSCRIBE_UPDATES_LINK: &str = "https://defguard.net/newsletter";
static JOIN_COMMUNITY_LINK: &str = "https://matrix.to/#/#defguard:teonite.com";
static FOLLOW_US_LINK: &str = "https://floss.social/@defguard";

#[must_use]
pub fn create_tray_menu() -> SystemTrayMenu {
pub async fn generate_tray_menu(app_state: State<'_, AppState>) -> Result<SystemTrayMenu, Error> {
let quit = CustomMenuItem::new("quit", "Quit");
let show = CustomMenuItem::new("show", "Show");
let hide = CustomMenuItem::new("hide", "Hide");
let subscribe_updates = CustomMenuItem::new("subscribe_updates", "Subscribe for updates");
let join_community = CustomMenuItem::new("join_community", "Join our Community");
let follow_us = CustomMenuItem::new("follow_us", "Follow us");
SystemTrayMenu::new()
let mut tray_menu = SystemTrayMenu::new();

// INSTANCE SECTION
info!("Load all instances for tray menu");
let all_instances = all_instances(app_state.clone()).await;
debug!("All instances: {:?}", all_instances);
if let Ok(instances) = all_instances {
for instance in instances {
let mut instance_menu = SystemTrayMenu::new();
let all_locations = all_locations(
instance.id.expect("Missing instannce id"),
app_state.clone(),
)
.await
.unwrap();
debug!(
"All locations {:?} in instance {:?}",
all_locations, instance
);

// TODO: apply icons instead of Connect/Disconnect when defguard utilizes tauri v2
for location in all_locations {
let item_name = if location.active {
format!("Disconnect: {}", location.name)
} else {
format!("Connect: {}", location.name)
};
instance_menu =
instance_menu.add_item(CustomMenuItem::new(location.id.to_string(), item_name));
debug!("Added new menu item for {:?}", location);
}
tray_menu = tray_menu.add_submenu(SystemTraySubmenu::new(instance.name, instance_menu));
}
} else if let Err(err) = all_instances {
warn!("Cannot load instance menu: {:?}", err);
}

// Load rest of tray menu options
tray_menu = tray_menu
.add_native_item(SystemTrayMenuItem::Separator)
.add_item(show)
.add_item(hide)
.add_native_item(SystemTrayMenuItem::Separator)
.add_item(subscribe_updates)
.add_item(join_community)
.add_item(follow_us)
.add_native_item(SystemTrayMenuItem::Separator)
.add_item(quit)
.add_item(quit);

info!("Successfully sets tray menu");
Ok(tray_menu)
}

pub async fn reload_tray_menu(app_handle: &AppHandle) {
let system_menu = generate_tray_menu(app_handle.state::<AppState>())
.await
.unwrap();
if let Err(err) = app_handle.tray_handle().set_menu(system_menu) {
warn!("Unable to update tray menu {err:?}");
}
}

fn show_main_window(app: &AppHandle) {
Expand All @@ -46,18 +102,9 @@ fn show_main_window(app: &AppHandle) {

// handle tray actions
pub fn handle_tray_event(app: &AppHandle, event: SystemTrayEvent) {
match event {
SystemTrayEvent::LeftClick { .. } => {
if let Some(main_window) = app.get_window("main") {
let visible = main_window.is_visible().unwrap_or_default();
if visible {
let _ = main_window.hide();
} else {
show_main_window(app);
}
}
}
SystemTrayEvent::MenuItemClick { id, .. } => match id.as_str() {
let handle = app.clone();
if let SystemTrayEvent::MenuItemClick { id, .. } = event {
match id.as_str() {
"quit" => {
info!("Received QUIT request. Initiating shutdown...");
let app_state: State<AppState> = app.state();
Expand All @@ -80,9 +127,13 @@ pub fn handle_tray_event(app: &AppHandle, event: SystemTrayEvent) {
"follow_us" => {
let _ = webbrowser::open(FOLLOW_US_LINK);
}
_ if id.chars().all(char::is_numeric) => {
tauri::async_runtime::spawn(async move {
handle_location_tray_menu(id, &handle).await;
});
}
_ => {}
},
_ => {}
}
}
}

Expand All @@ -99,3 +150,57 @@ pub fn configure_tray_icon(app: &AppHandle, theme: &TrayIconTheme) -> Result<(),
Err(Error::ResourceNotFound(resource_str))
}
}

#[derive(Clone, serde::Serialize)]
struct Payload {
message: String,
}

async fn handle_location_tray_menu(id: String, handle: &AppHandle) {
match id.parse::<i64>() {
Ok(location_id) => {
match Location::find_by_id(&handle.state::<AppState>().get_pool(), location_id).await {
Ok(Some(location)) => {
let active_locations_ids: Vec<i64> = handle
.state::<AppState>()
.get_connection_id_by_type(&ConnectionType::Location)
.await;

if active_locations_ids.contains(&location_id) {
info!("Disconnect location with id {}", id);
let _ =
disconnect(location_id, ConnectionType::Location, handle.clone()).await;
} else {
info!("Connect location with id {}", id);
// check is mfa enabled and trigger modal on frontend
if location.mfa_enabled {
info!(
"mfa enabled for location with id {:?}, trigger mfa modal",
location.id.expect("Missing location id")
);
handle
.emit_all(
"mfa-trigger",
Payload {
message: "Trigger mfa event".into(),
},
)
.unwrap();
} else {
let _ = connect(
location_id,
ConnectionType::Location,
Some(location.pubkey),
handle.clone(),
)
.await;
}
}
}
Ok(None) => warn!("Location does not exist"),
Err(e) => warn!("Unable to find location: {e:?}"),
};
}
Err(e) => warn!("Can't handle event due to: {e:?}"),
}
}
3 changes: 1 addition & 2 deletions src-tauri/tauri.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
"tauri": {
"systemTray": {
"iconPath": "resources/icons/tray-32x32-color.png",
"iconAsTemplate": false,
"menuOnLeftClick": false
"iconAsTemplate": false
},
"allowlist": {
"all": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ export const AddInstanceInitForm = ({ nextStep }: Props) => {
})
.catch((e) => {
setIsLoading(false);
console.log(e);
if (typeof e === 'string') {
if (e.includes('Network Error')) {
toaster.error(LL.common.messages.networkError());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import './style.scss';

import { listen } from '@tauri-apps/api/event';
import classNames from 'classnames';
import { useState } from 'react';
import { useEffect, useState } from 'react';
import { error } from 'tauri-plugin-log-api';

import { useI18nContext } from '../../../../../../../../i18n/i18n-react';
Expand All @@ -23,6 +24,10 @@ type Props = {
location?: CommonWireguardFields;
};

type Payload = {
location?: CommonWireguardFields;
};

export const LocationCardConnectButton = ({ location }: Props) => {
const toaster = useToaster();
const [isLoading, setIsLoading] = useState(false);
Expand Down Expand Up @@ -62,6 +67,17 @@ export const LocationCardConnectButton = ({ location }: Props) => {
}
};

useEffect(() => {
async function listenMFAEvent() {
await listen<Payload>('mfa-trigger', () => {
if (location) {
openMFAModal(location);
}
});
}
listenMFAEvent();
}, [openMFAModal, location]);

return (
<Button
onClick={handleClick}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ export const StatsLayoutSelect = ({ locations }: StatsLayoutSelect) => {

const renderSelected: SelectProps<ClientView>['renderSelected'] = useCallback(
(value): SelectSelectedValue => {
console.log(locations);
const selected = options.find((o) => o.value === value);
if (selected) {
return {
Expand Down

0 comments on commit f2f7310

Please sign in to comment.