class RestartDump(Callback):
    """Callback that persists per-rank restart checkpoints for later replay.

    The callback runs at stage ``"end"`` and captures one shard per MPI rank
    inside ``out_dir/ckpt_<itime>/``. Each shard stores the full
    ``Simulation`` state so a subsequent :meth:`RestartDump.load` call can
    resume execution on the same rank topology.

    Parameters
    ----------
    out_dir : str | Path
        Root directory that will hold checkpoint folders.
    interval : int | float | Callable
        Dump cadence; accepts step counts, wall-clock seconds, or a predicate
        callable that mirrors the base callback interval semantics.
    keep : int | None
        Number of most recent checkpoints to retain. When set, rank 0 trims
        older directories after a successful dump.
    dump_signals : Sequence[int] | bool
        POSIX signals that trigger an immediate checkpoint; ``True`` registers
        ``SIGINT`` and ``SIGTERM``, ``False`` disables signal-triggered dumps.

    Attributes
    ----------
    stage : str
        Simulation stage where the callback executes.

    Examples
    --------
    >>> sim = Simulation2D(...)
    >>> sim.run(callbacks=[RestartDump('checkpoints', interval=100)])

    To restart the simulation, replace the ``sim`` instance with the loaded
    one before calling ``sim.run``:

    >>> sim = RestartDump.load('checkpoints/ckpt_000100')
    >>> sim.run(callbacks=...)  # Continue from checkpoint

    Note
    ----
    Setting ``dump_signals`` allows automatic checkpointing when simulation is
    stopped by time limit of job scheduler like slurm.
    """

    DEFAULT_STAGE = "end"

    def __init__(
        self,
        out_dir: Union[str, Path],
        interval: Union[int, float, Callable] = 1000,
        keep: Optional[int] = None,
        dump_signals: list[int] | bool = False,
    ) -> None:
        self.stage = self.DEFAULT_STAGE
        self.out_dir = Path(out_dir)
        self.interval = interval
        self.keep = keep
        self.out_dir.mkdir(parents=True, exist_ok=True)

        # Initialise the flag BEFORE installing handlers so that a signal
        # delivered during construction is not overwritten with False.
        self._dump_requested = False

        if dump_signals is False:
            self.dump_signals = []
        elif dump_signals is True:
            self.dump_signals = [signal.SIGINT, signal.SIGTERM]
        else:
            # Copy so later mutation of the caller's sequence cannot change
            # which signals this callback believes it has registered.
            self.dump_signals = list(dump_signals)

        # Installing a handler replaces the process-wide disposition for each
        # signal (e.g. SIGINT no longer raises KeyboardInterrupt).
        for sig in self.dump_signals:
            signal.signal(sig, self._dump_handler)

    def _dump_handler(self, sig, frame):
        """Signal handler: only record that a checkpoint was requested."""
        # NOTE(review): nothing in this class reads ``_dump_requested``;
        # presumably the base callback or the run loop consumes it - confirm.
        self._dump_requested = True

    # ---------------------- save path helpers ----------------------
    def _ckpt_dir(self, itime: int) -> Path:
        """Return the checkpoint directory for step ``itime``."""
        return self.out_dir / f"ckpt_{itime:06d}"

    def _rank_shard_path(self, itime: int, rank: int) -> Path:
        """Return the shard file written by ``rank`` for step ``itime``."""
        return self._ckpt_dir(itime) / f"rank_{rank:06d}.pkl"

    # ---------------------- callback entry ----------------------
    def _call(self, sim: Union[Simulation, Simulation3D]):
        """Write this rank's shard for the current step, then prune old ones.

        Every rank must enter this method: it synchronises on three barriers.
        """
        comm = sim.mpi.comm
        rank = sim.mpi.rank
        step = sim.itime

        if rank == 0:
            self._ckpt_dir(step).mkdir(parents=True, exist_ok=True)
        # Nobody writes before rank 0 has created the directory.
        comm.Barrier()

        # All ranks write shards. Write to a temporary sibling first and
        # rename afterwards, so an interrupted dump never leaves a truncated
        # file under the final ``rank_*.pkl`` name.
        shard = self._rank_shard_path(step, rank)
        partial = shard.with_name(shard.name + ".tmp")
        with open(partial, "wb") as fh:
            # ``byref``/``recurse`` are not stdlib-pickle options; ``pickle``
            # is presumably an alias for dill in this module - confirm.
            pickle.dump(sim, fh, byref=True, recurse=True)
        partial.replace(shard)
        comm.Barrier()

        # Optionally trim old checkpoints (rank 0 only)
        if rank == 0 and self.keep is not None and self.keep > 0:
            self._gc_old_checkpoints(self.keep)
        comm.Barrier()

    def _gc_old_checkpoints(self, keep: int) -> None:
        """Delete all but the ``keep`` highest-numbered ``ckpt_*`` directories.

        Directories are ordered by the integer step parsed from their name,
        not by the name itself: names are zero-padded to six digits, so a
        lexicographic sort would rank ``ckpt_1000000`` before ``ckpt_999999``
        and prune the newest checkpoints. Entries whose suffix is not an
        integer are left untouched. Removal failures are logged, not raised.
        """
        import shutil

        prefix = "ckpt_"
        numbered = []
        for entry in self.out_dir.iterdir():
            if not (entry.is_dir() and entry.name.startswith(prefix)):
                continue
            try:
                step = int(entry.name[len(prefix):])
            except ValueError:
                # Not a directory name this callback generates; keep it.
                continue
            numbered.append((step, entry))

        if len(numbered) <= keep:
            return

        numbered.sort(key=lambda item: item[0])
        for _, stale in numbered[: len(numbered) - keep]:
            try:
                shutil.rmtree(stale)
            except OSError as exc:
                # Pruning is best-effort; a failure must not abort the run.
                logger.warning(f"Failed to remove old checkpoint {stale}: {exc}")

    # ---------------------- loader ----------------------
    @staticmethod
    def load(ckpt_dir: Union[str, Path], comm=None) -> Union[Simulation, Simulation3D]:
        """Load a Simulation from a RestartDump checkpoint directory.

        Parameters
        ----------
        ckpt_dir : str | Path
            Path to a single checkpoint directory (``ckpt_xxxxxx``).
        comm : mpi4py.MPI.Comm, optional
            MPI communicator to use. When ``None``,
            ``MPIManager.get_default_comm()`` is used.

        Returns
        -------
        Simulation or Simulation3D
            Instance restored to the checkpoint state, with ``itime`` already
            advanced by one.

        Raises
        ------
        FileNotFoundError
            If ``ckpt_dir`` holds no shard for the calling rank.

        Warning
        -------
        The shard is unpickled, which can execute arbitrary code. Only load
        checkpoints from a trusted source.
        """
        ckpt_dir = Path(ckpt_dir)
        if comm is None:
            comm = MPIManager.get_default_comm()
        rank = comm.Get_rank()

        # Each rank reads only the shard carrying its own rank number.
        shard_path = ckpt_dir / f"rank_{rank:06d}.pkl"
        with open(shard_path, "rb") as fh:
            sim = pickle.load(fh)

        # Project helper; what it rebuilds is defined outside this class.
        sim.update_lists()

        # inc by 1, since restart is called before itime inc
        sim.itime += 1

        comm.Barrier()
        logger.info(f"Rank {rank}: Checkpoint loaded from {ckpt_dir}, itime={sim.itime}")
        return sim