Skip to content

Commit

Permalink
Solve several bugs to get the two step load balancing functionality t…
Browse files Browse the repository at this point in the history
…o work
  • Loading branch information
IshaanDesai committed Jan 31, 2025
1 parent 7e4aa0d commit 8ef1b03
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 101 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/run-adaptivity-tests-parallel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ jobs:
working-directory: micro-manager/tests/unit
run: mpiexec -n 2 --allow-run-as-root python3 -m unittest test_adaptivity_parallel.py

- name: Run load balancing unit tests
- name: Run load balancing unit tests with 2 ranks
working-directory: micro-manager/tests/unit
run: mpiexec -n 2 --allow-run-as-root python3 -m unittest test_global_adaptivity_lb.py

- name: Run load balancing tests with 4 ranks
working-directory: micro-manager/tests/unit
run: mpiexec -n 4 --allow-run-as-root --oversubscribe python3 -m unittest test_global_adaptivity_lb.py
41 changes: 10 additions & 31 deletions micro_manager/adaptivity/global_adaptivity_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self._local_number_of_sims = len(global_ids)

self._is_load_balancing_done_in_two_steps = (
configurator.get_two_step_load_balancing()
configurator.is_load_balancing_two_step()
)

def redistribute_sims(self, micro_sims: list) -> None:
Expand Down Expand Up @@ -129,6 +129,7 @@ def _redistribute_active_sims(self, micro_sims: list) -> None:

if excess_recv_sims == 0:
break

elif n_global_send_sims > n_global_recv_sims:
excess_send_sims = n_global_send_sims - n_global_recv_sims
while excess_send_sims > 0:
Expand Down Expand Up @@ -165,35 +166,15 @@ def _redistribute_inactive_sims(self, micro_sims):
ranks_of_sims = self._get_ranks_of_sims()

global_ids_of_inactive_sims = np.where(self._is_sim_active == False)[0]
global_ids_of_active_sims = np.where(self._is_sim_active)[0]

current_ranks_of_active_sims = []
associated_inactive_sims = (
dict()
) # Keys are global IDs of active sims, values are lists of global IDs of inactive sims associated to them

for active_gid in global_ids_of_active_sims:
current_ranks_of_active_sims.append(ranks_of_sims[active_gid])

associated_inactive_sims[active_gid] = [
i for i, x in enumerate(self._sim_is_associated_to) if x == active_gid
]

current_ranks_of_inactive_sims = []
new_ranks_of_inactive_sims = []
for inactive_gid in global_ids_of_inactive_sims:
current_ranks_of_inactive_sims.append(ranks_of_sims[inactive_gid])
assoc_active_gid = self._sim_is_associated_to[inactive_gid]

new_ranks_of_inactive_sims = current_ranks_of_inactive_sims.copy()
current_ranks_of_inactive_sims.append(ranks_of_sims[inactive_gid])

for active_gid, assoc_inactive_gids in associated_inactive_sims.items():
for inactive_gid in assoc_inactive_gids:
inactive_idx = np.where(global_ids_of_inactive_sims == inactive_gid)[0][
0
]
active_idx = np.where(global_ids_of_active_sims == active_gid)[0][0]
new_ranks_of_inactive_sims[inactive_idx] = current_ranks_of_active_sims[
active_idx
]
new_ranks_of_inactive_sims.append(ranks_of_sims[assoc_active_gid])

# keys are global IDs of sim states to send, values are ranks to send the sims to
send_map: dict[int, int] = dict()
Expand All @@ -215,12 +196,10 @@ def _get_communication_maps(self, global_send_sims, global_recv_sims):
"""
...
"""

global_ids_of_active_sims_local = list(
np.where(
self._is_sim_active[self._global_ids[0] : self._global_ids[-1] + 1]
)[0]
)
global_ids_of_active_sims_local = []
for global_id in self._global_ids:
if self._is_sim_active[global_id] == True:
global_ids_of_active_sims_local.append(global_id)

rank_wise_global_ids_of_active_sims = self._comm.allgather(
global_ids_of_active_sims_local
Expand Down
21 changes: 18 additions & 3 deletions micro_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,13 @@ def read_json_micro_manager(self):
)

try:
self._two_step_load_balancing = self._data["simulation_params"][
"adaptivity_settings"
]["two_step_load_balancing"]
if (
self._data["simulation_params"]["adaptivity_settings"][
"two_step_load_balancing"
]
== "True"
):
self._two_step_load_balancing = True
except BaseException:
self._logger.log_info_one_rank(
"Two-step load balancing is not specified. Micro Manager will only try to balance the load in one sweep." # TODO: Need a better log message here.
Expand Down Expand Up @@ -616,6 +620,17 @@ def get_load_balancing_n(self):
"""
return self._load_balancing_n

def is_load_balancing_two_step(self):
"""
Check if two-step load balancing is required.
Returns
-------
two_step_load_balancing : bool
True if two-step load balancing is required, False otherwise.
"""
return self._two_step_load_balancing

def get_micro_dt(self):
"""
Get the size of the micro time window.
Expand Down
7 changes: 0 additions & 7 deletions micro_manager/micro_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,8 @@ def _solve_micro_simulations_with_adaptivity(
if self._is_micro_solve_time_required:
micro_sims_output[inactive_id]["solve_cpu_time"] = 0

print(
"Rank {}: data_for_adaptivity = {}".format(
self._rank, self._data_for_adaptivity
)
)

# Collect micro sim output for adaptivity calculation
for i in range(self._local_number_of_sims):
print("Rank {}: i = {}".format(self._rank, i))
for name in self._adaptivity_micro_data_names:
self._data_for_adaptivity[name][i] = micro_sims_output[i][name]

Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_adaptivity_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,24 @@ def test_communicate_micro_output(self):
)

self.assertTrue(np.array_equal(expected_sim_output, sim_output))

def test_get_ranks_of_sims(self):
""" """
if self._rank == 0:
global_ids = [0, 1, 2]
expected_ranks_of_sims = [0, 0, 0, 1, 1]
elif self._rank == 1:
global_ids = [3, 4]
expected_ranks_of_sims = [0, 0, 0, 1, 1]

adaptivity_controller = GlobalAdaptivityCalculator(
self._configurator,
self._global_number_of_sims,
global_ids,
rank=self._rank,
comm=self._comm,
)

actual_ranks_of_sims = adaptivity_controller._get_ranks_of_sims()

self.assertTrue(np.array_equal(expected_ranks_of_sims, actual_ranks_of_sims))
Loading

0 comments on commit 8ef1b03

Please sign in to comment.