"""Load balancing for MPI-distributed PIC simulations."""
from collections.abc import Callable
import numpy as np
from numpy.typing import NDArray
from ..mpi.mpi_manager import MPIManager
from ..patch.patch import Patch, Patches
from ..utils.enable_mixin import EnableMixin
from ..utils.logger import logger, rank_log
[docs]
class LoadBalancer(EnableMixin):
    """Handles dynamic load balancing of patches across MPI ranks.

    The balancer keeps one load value per locally owned patch
    (``local_loads``), decides whether the per-rank totals are imbalanced
    (``should_rebalance``) and, when called, recomputes the patch-to-rank
    assignment on rank 0 and migrates patches between ranks.

    ``update_weights``, ``should_rebalance`` and ``__call__`` all contain
    collective MPI operations (or feed them), so every rank of the
    communicator must call them in the same order.

    The imbalance threshold is adaptive: it is multiplied by ``inc_factor``
    when a rebalance did not remove the imbalance, and relaxed back towards
    the initial value by ``dec_factor`` once a rebalance succeeds.
    """

    def __init__(
        self,
        patches: Patches,
        mpi: MPIManager,
        threshold: float = 0.1,
        load_function: Callable[[Patch], float] | None = None,
    ) -> None:
        """Initialize with patches (called during simulation init).

        Parameters
        ----------
        patches : Patches
            The patches to be load balanced.
        mpi : MPIManager
            The MPI manager for communication.
        threshold : float, optional
            Load imbalance threshold for triggering rebalance.
            Rebalance is triggered when (max_load - min_load) / avg_load > threshold.
        load_function : callable | None, optional
            Custom function to calculate load for a single patch. The function
            should accept a `Patch` as its only parameter and return a float
            representing the load. If None, uses the default load calculation.
        """
        self.patches = patches
        self.dimension = patches.dimension
        self.comm = mpi.comm
        self.rank = mpi.rank
        self.comm_size = mpi.size

        # Test against None explicitly: ``load_function or default`` would
        # silently discard a callable object that happens to be falsy
        # (e.g. one defining ``__len__`` or ``__bool__``).
        if load_function is None:
            load_function = LoadBalancer._default_load_function
        self.load_function = load_function

        # One load entry per locally owned patch, in local patch order.
        self.local_loads = np.zeros(len(patches), dtype=np.float64)

        self.threshold = threshold
        # Remembered so the adaptive threshold can relax back to it.
        self._init_threshold = threshold
        # Multipliers applied to the threshold after a rebalance:
        # dec_factor (< 1) when balanced, inc_factor (> 1) when still unbalanced.
        self.dec_factor = 3 / np.pi
        self.inc_factor = np.e / 2

    def __call__(self) -> None:
        """Execute rebalance.

        Does nothing on a single-rank communicator. Otherwise gathers the
        loads, computes a new distribution on rank 0, migrates the patches,
        refreshes neighbor information and the local loads, and finally
        adapts ``self.threshold`` depending on whether the result is balanced.
        """
        if self.comm_size == 1:
            return

        _, global_loads = self._gather_loads()
        patches_list, index_to_new_rank = self._compute_distribution(global_loads)

        # Rank 0 owns the result; hand each rank its skeleton patches and
        # give everybody the full index -> new rank map.
        patches_new: Patches = self.comm.scatter(patches_list, root=0)
        index_to_new_rank = self.comm.bcast(index_to_new_rank, root=0)

        self._exchange(patches_new, index_to_new_rank)
        self._finalize(patches_new)
        self.update_weights()

        if self.should_rebalance():
            self.threshold *= self.inc_factor
            if self.rank == 0:
                logger.info(f"still unbalanced after rebalance, "
                            f"increasing threshold to {self.threshold:.2f}", self.comm)
        elif self.threshold > self._init_threshold:
            # Relax towards the configured threshold but never below it:
            # a plain multiplication can overshoot the initial value.
            self.threshold = max(
                self.threshold * self.dec_factor, self._init_threshold
            )
            if self.rank == 0:
                logger.info(f"balanced after rebalance, "
                            f"decreasing threshold to {self.threshold:.2f}", self.comm)

    @staticmethod
    def _default_load_function(patch: Patch) -> float:
        """Default load calculation for a single patch.

        Load calculation:
        - 2D: load = npart + nx*ny/2
        - 3D: load = npart + nx*ny*nz/2

        ``npart`` is the sum of ``_npart_alive`` over ``patch.particles``.
        A patch is treated as 3D when it has an ``nz`` attribute.
        """
        load = sum((part._npart_alive for part in patch.particles), 0.0)
        if hasattr(patch, 'nz'):
            load += patch.nx * patch.ny * patch.nz / 2
        else:
            load += patch.nx * patch.ny / 2
        return load

    def _gather_loads(self) -> tuple[NDArray, NDArray | None]:
        """Gather load information from all MPI ranks.

        Returns
        -------
        local_indices : NDArray
            Indices (``patch.index``) of the patches owned by this rank.
        global_loads : NDArray | None
            On rank 0, an array where entry ``i`` is the load of the patch
            whose index is ``i``; ``None`` on every other rank.
        """
        local_indices = np.array([p.index for p in self.patches], dtype=np.int64)
        all_loads = self.comm.gather(self.local_loads, root=0)
        all_indices = self.comm.gather(local_indices, root=0)

        global_loads = None
        if self.rank == 0:
            assert all_loads is not None
            assert all_indices is not None
            npatches_total = sum(len(loads) for loads in all_loads)
            global_loads = np.zeros(npatches_total, dtype=np.float64)
            # Patch indices are used directly as positions in global_loads.
            # Fancy-index assignment replaces the per-element Python loop and
            # raises if a rank's loads and indices differ in length instead
            # of silently truncating.
            for loads, indices in zip(all_loads, all_indices):
                global_loads[indices] = loads
        return local_indices, global_loads

    def _compute_distribution(
        self,
        global_loads: NDArray | None,
    ) -> tuple[list[Patches] | None, dict | None]:
        """Compute new patch distribution.

        Every rank contributes skeleton copies of its patches; rank 0
        assembles them (sorted by patch index), asks ``compute_rank`` for a
        new rank per patch and rebuilds the neighbor ranks.

        Parameters
        ----------
        global_loads : NDArray | None
            Per-patch loads indexed by patch index; required on rank 0.

        Returns
        -------
        patches_list : list[Patches] | None
            On rank 0, one ``Patches`` container of skeletons per rank;
            ``None`` elsewhere.
        index_to_new_rank : dict | None
            On rank 0, mapping patch index -> new owning rank; ``None``
            elsewhere.
        """
        from ..patch.metis import compute_rank

        all_patches_skeleton = self.comm.gather(
            [p.copy_skeleton() for p in self.patches], root=0
        )

        patches_list = None
        index_to_new_rank = None
        if self.rank == 0:
            assert all_patches_skeleton is not None
            assert global_loads is not None

            # Global container ordered by patch index so that position i
            # lines up with global_loads[i].
            patches_all = Patches(self.dimension)
            for patch in sorted(
                (patch for patches in all_patches_skeleton for patch in patches),
                key=lambda patch: patch.index,
            ):
                patches_all.append(patch)

            # NOTE(review): the cast truncates fractional loads, so a custom
            # load_function returning values < 1 yields zero weights --
            # confirm this is acceptable for compute_rank.
            weights = global_loads.astype(np.int64)
            new_ranks, _ = compute_rank(
                patches_all,
                nrank=self.comm_size,
                weights=weights,
                rank_prev=np.array([p.rank for p in patches_all]),
            )

            for p, new_rank in zip(patches_all.patches, new_ranks):
                p.rank = new_rank

            if self.dimension == 2:
                patches_all.init_neighbor_rank_2d()
            else:
                patches_all.init_neighbor_rank_3d()

            # Split the global container into one container per target rank.
            patches_list = [Patches(self.dimension) for _ in range(self.comm_size)]
            index_to_new_rank = {}
            for p in patches_all.patches:
                assert p.rank is not None
                patches_list[p.rank].append(p)
                index_to_new_rank[p.index] = p.rank
        return patches_list, index_to_new_rank

    def _exchange(
        self,
        patches_new: Patches,
        index_to_new_rank: dict,
    ) -> None:
        """Exchange patches between ranks.

        Patches owned now but absent from ``patches_new`` are removed from
        ``self.patches``, serialized with dill and sent (non-blocking) to
        their new rank; patches present in ``patches_new`` but not owned yet
        are received (blocking) from their previous owner and appended to
        ``self.patches``. The patch index is used as the MPI message tag.

        Parameters
        ----------
        patches_new : Patches
            Skeleton patches this rank should own after the rebalance.
        index_to_new_rank : dict
            Mapping patch index -> new owning rank for all patches.
        """
        import dill as pickle
        from mpi4py import MPI

        patches = self.patches
        comm = self.comm
        rank = self.rank

        old_patch_indices = {p.index: p for p in patches}
        new_patch_indices = {p.index: p for p in patches_new}
        patches_to_send_idx = [
            idx for idx in old_patch_indices if idx not in new_patch_indices
        ]
        patches_to_recv_idx = [
            idx for idx in new_patch_indices if idx not in old_patch_indices
        ]
        logger.debug(f"patches to send: {patches_to_send_idx}", comm)
        logger.debug(f"patches to receive: {patches_to_recv_idx}", comm)

        # Current ownership of every patch, collected before anything moves.
        all_locations = comm.allgather({p.index: rank for p in patches})
        index_to_old_rank = {}
        for loc in all_locations:
            index_to_old_rank.update(loc)

        requests = []
        patches_to_send = []
        # NOTE(review): idx is a patch index (patch.index), not a list
        # position -- confirm Patches.pop() looks patches up by index.
        for idx in patches_to_send_idx:
            patches_to_send.append(patches.pop(idx))

        for p in patches_to_send:
            target_rank = index_to_new_rank[p.index]
            data = pickle.dumps(p, byref=True, recurse=True)
            logger.debug(f"sending patch {p.index} to rank {target_rank}", comm)
            # NOTE(review): the tag is the patch index; MPI only guarantees
            # tags up to MPI_TAG_UB -- confirm patch indices stay below it.
            req = comm.isend(data, dest=target_rank, tag=p.index)
            requests.append(req)

        for idx in patches_to_recv_idx:
            source_rank = index_to_old_rank[idx]
            if source_rank == rank:
                continue
            logger.debug(f"receiving patch {idx} from rank {source_rank}", comm)
            data = comm.recv(source=source_rank, tag=idx)
            p = pickle.loads(data)
            p.rank = rank
            patches.append(p)

        # Make sure every outgoing message has completed before returning.
        MPI.Request.waitall(requests)

    def _finalize(self, patches_new: Patches) -> None:
        """Finalize rebalance.

        Copies the neighbor ranks computed for the skeleton patches in
        ``patches_new`` onto the matching (by index) real patches in
        ``self.patches`` and rebuilds the local neighbor patch indices.
        """
        new_patch_indices = {p.index: p for p in patches_new}
        patches = self.patches
        for p in patches:
            p.neighbor_rank[:] = new_patch_indices[p.index].neighbor_rank[:]

        if self.dimension == 2:
            patches.init_neighbor_ipatch_2d()
        else:
            patches.init_neighbor_ipatch_3d()

    def update_weights(self) -> None:
        """Recompute ``local_loads`` for the patches currently owned.

        Reallocates ``local_loads`` when the number of local patches has
        changed, then evaluates ``load_function`` on every local patch.
        """
        if self.local_loads.size != self.patches.npatches:
            self.local_loads = np.zeros(self.patches.npatches, dtype=np.float64)
            logger.debug(f"Rank {self.rank}: resized local_loads to {self.patches.npatches}")

        for ipatch, p in enumerate(self.patches):
            self.local_loads[ipatch] = self.load_function(p)
            logger.debug(f"Rank {self.rank}: patch {p.index} load: {self.local_loads[ipatch]}")

    def should_rebalance(self) -> bool:
        """Return whether the per-rank loads exceed the imbalance threshold.

        This is a collective operation (it performs an ``allgather``), so
        every rank must call it. All ranks obtain the same answer.

        Returns
        -------
        bool
            True when (max_load - min_load) / avg_load > threshold, where
            the loads are the per-rank sums of ``local_loads``. False when
            the average load is zero.
        """
        # A rank without patches contributes a total of 0 instead of
        # returning early: skipping the allgather below would leave the
        # other ranks blocked in the collective.
        if self.local_loads is None:
            local_total = 0.0
        else:
            local_total = float(np.sum(self.local_loads))

        # Gather total weights from all MPI ranks
        all_totals = np.array(self.comm.allgather(local_total), dtype=np.float64)

        # Check global load imbalance across ranks
        max_load = np.max(all_totals)
        min_load = np.min(all_totals)
        avg_load = np.mean(all_totals)
        if avg_load == 0:
            return False

        return bool((max_load - min_load) / avg_load > self.threshold)