from functools import wraps
from typing import Callable, Optional

from yaspin import yaspin

from ..core.utils.logger import logger
from ..core.utils.terminal import is_terminal
from ..core.utils.timer import Timer
from ..simulation import Simulation


def _validate_interval(interval: int | float | Callable) -> None:
    """Check that ``interval`` is an acceptable callback interval.

    Accepted forms:
      * an ``int`` of at least 1 (number of iterations between calls);
        note that ``bool`` is an ``int`` subclass and is treated as one,
      * a ``float`` that is neither ``<= 0`` nor ``>= 1`` (time interval),
      * any ``Callable`` (a per-step predicate).

    Raises:
        TypeError: ``interval`` is not an int, float, or Callable.
        ValueError: ``interval`` is an int below 1, or a float ``<= 0`` or ``>= 1``.
    """
    allowed_types = (int, float, Callable)
    if not isinstance(interval, allowed_types):
        raise TypeError(f"Invalid interval: {interval}. Must be int, float, or Callable")

    # Written as two comparisons (not ``not 0 < x < 1``) so NaN is not rejected.
    float_out_of_range = isinstance(interval, float) and (interval <= 0 or interval >= 1)
    if float_out_of_range:
        raise ValueError(f"Invalid interval: {interval}. Must be between 0 and 1s if it is a float")

    int_too_small = isinstance(interval, int) and interval < 1
    if int_too_small:
        raise ValueError(f"Invalid interval: {interval}. Must be greater than 0 if it is an integer")


def _interval_triggered(sim: Simulation, interval: int | float | Callable) -> bool:
    """Decide whether a callback with the given ``interval`` should run now.

    * Callable: the truthiness of ``interval(sim)``.
    * int: ``True`` when ``sim.itime`` is a multiple of ``interval``.
    * float: ``True`` when ``sim.time % interval`` is below ``sim.dt``, i.e. on
      the step that lies within one ``dt`` after a multiple of ``interval``.
    * anything else: ``True``.

    Raises:
        AttributeError: a float interval is used but ``sim`` has no ``time``
            attribute (or it is ``None``).
    """
    if callable(interval):
        return bool(interval(sim))

    if isinstance(interval, int):
        return sim.itime % interval == 0

    if not isinstance(interval, float):
        return True

    current_time = getattr(sim, "time", None)
    if current_time is None:
        raise AttributeError(
            "Simulation instance must provide `time` when using float interval callbacks."
        )
    remainder = current_time % interval
    return remainder < sim.dt
def callback(stage: Optional[str] = None, interval: int | float | Callable = 1) -> Callable:
    """
    A decorator for implementing callbacks in PIC simulations.

    This decorator allows functions to be attached to specific simulation stages,
    enabling dynamic behavior modification without changing the core simulation code.

    Args:
        stage: The simulation stage at which this callback should be executed.
            It is stored as ``wrapper.stage``. When ``None`` (the default), the
            simulation is expected to fall back to
            ``Simulation.default_callback_stage()``.
        interval (int|float|Callable):
            if int (>= 1), the number of iterations between calls; the callback
            runs when ``sim.itime`` is a multiple of it.
            if float (strictly between 0 and 1), the time interval between calls;
            the callback runs when ``sim.time % interval < sim.dt``, so ``sim``
            must provide ``time`` and ``dt``.
            if Callable, a predicate taking the Simulation object and returning
            a boolean that decides whether to call the callback on this step.
            Defaults to 1 (call every iteration).

    Returns:
        Callable: A decorator. Applying it to a function returns a plain wrapper
        function (not a ``Callback`` instance). The wrapper:

        * finds the simulation in the ``sim=`` keyword, or else in the last
          positional argument;
        * returns ``None`` without calling the function when the interval is
          not triggered;
        * otherwise runs the function under a ``Timer``, with a spinner on
          rank 0 in a terminal, or a log line on rank 0 otherwise;
        * calls ``sim.mpi.comm.Barrier()`` on every rank and returns the
          function's result.

    Raises:
        TypeError, ValueError: when the decorator is applied with an invalid
            ``interval`` (see the accepted forms above).

    Example:
        >>> @callback(stage="maxwell_1", interval=100)
        ... def custom_field_modification(sim):
        ...     for patch in sim.patches:
        ...         patch.fields.ex *= 1.1  # Amplify Ex field by 10%
    """

    def decorator(func: Callable) -> Callable:
        _validate_interval(interval)
        name = func.__name__
        timer_label = f"callback: {name}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            # The simulation is conventionally the last positional argument;
            # also accept it passed by keyword instead of failing with IndexError.
            sim = kwargs["sim"] if "sim" in kwargs else args[-1]
            if not _interval_triggered(sim, interval):
                return

            # Only rank 0 reports progress; ``is_terminal()`` is only queried there.
            if sim.mpi.rank == 0 and is_terminal():
                with yaspin(text=f"Running callback: {name}"), Timer(timer_label):
                    ret = func(*args, **kwargs)
            else:
                if sim.mpi.rank == 0:
                    logger.info(f"Running callback: {name}")
                with Timer(timer_label):
                    ret = func(*args, **kwargs)

            # Barrier is a collective: every rank must reach it, whatever branch ran.
            sim.mpi.comm.Barrier()
            return ret

        # Expose the requested stage so the simulation can schedule the wrapper.
        wrapper.stage = stage
        return wrapper

    return decorator
class Callback:
    """A base class for implementing callbacks in PIC simulations.

    Subclasses must provide:
      * ``interval``: an int (>= 1, iterations), a float (strictly between 0 and
        1, time interval), or a Callable predicate taking the simulation;
      * ``stage``: the simulation stage name;
      * ``_call(sim)``: the actual work. Its return value is returned from
        ``__call__``.

    Calling an instance validates ``interval`` and checks whether it triggers on
    the current step. If it does, the instance runs ``_call`` under a ``Timer``
    and then calls ``sim.mpi.comm.Barrier()`` on every rank.
    """

    # Annotated only (no defaults): concrete subclasses must set these.
    interval: int | float | Callable
    stage: str

    def __call__(self, sim: Simulation):
        """Run ``_call(sim)`` if ``interval`` triggers on this step.

        Returns:
            ``None`` when not triggered, otherwise the result of ``_call(sim)``.

        Raises:
            TypeError, ValueError: ``self.interval`` is invalid.
            AttributeError: a float interval is used but ``sim.time`` is missing.
        """
        # Validated on every call because ``interval`` is set by subclasses
        # (possibly per instance), not passed through a constructor here.
        _validate_interval(self.interval)
        if not _interval_triggered(sim, self.interval):
            return

        name = self.__class__.__name__
        timer_label = f"callback: {name}"

        # Only rank 0 reports progress; ``is_terminal()`` is only queried there.
        if sim.mpi.rank == 0 and is_terminal():
            with yaspin(text=f"Running callback: {name}"), Timer(timer_label):
                ret = self._call(sim)
        else:
            if sim.mpi.rank == 0:
                logger.info(f"Running callback: {name}")
            with Timer(timer_label):
                ret = self._call(sim)

        # Barrier is a collective: every rank must reach it, whatever branch ran.
        sim.mpi.comm.Barrier()
        return ret

    def _call(self, sim: Simulation):
        """Perform the callback's work; must be overridden by subclasses."""
        raise NotImplementedError