Writing your own callbacks¶
The @callback decorator¶
The @callback decorator adds stage and interval attributes to the function.
The stage specifies when the callback should be called. Leave it empty to run the callback at the end of each loop.
The interval specifies how frequent the callback should be called. It can be a number or a function that returns a boolean.
interval = lambda sim: sim.itime == 42 means the callback will be called at the 42nd iteration.
By passing to the sim.run(1000, callbacks=[your callbacks]), they will be sequentially called by the Simulation.run method.
Note
The @callback decorator and Callback base class detect whether the process
is running in a terminal (via is_terminal()). In an interactive terminal, a yaspin
spinner is displayed during callback execution. In non-terminal environments (e.g. batch
jobs, pipes, or log files), callback execution is logged via the logger instead. Built-in
callbacks inheriting from Callback share this terminal-aware behavior.
- class lambdapic.callback.callback.callback(stage: str | None = None, interval: int | float | Callable = 1)[source]¶
A decorator for implementing callbacks in PIC simulations.
This decorator allows functions to be attached to specific simulation stages, enabling dynamic behavior modification without changing the core simulation code.
- Parameters:
stage – The simulation stage at which this callback should be executed. Defaults to
Simulation.default_callback_stage()when not specified.interval (int|float|Callable) –
if int, The number of iterations between calls to the callback function.
if float, The time interval in seconds between calls to the callback function.
- if Callable, The function to determine whether to call the callback function.
The function should take a Simulation object as an argument and return a boolean value.
Defaults to 1 (call every iteration).
- Returns:
The decorated callable object (an instance of a Callback subclass).
- Return type:
Callable
Example
>>> @callback(stage="maxwell_1", interval=100) ... def custom_field_modification(sim): ... for patch in sim.patches: ... patch.fields.ex *= 1.1 # Amplify Ex field by 10%
Hello world¶
@callback('start', interval=lambda sim: sim.itime == 0)
def hello(sim: Simulation):
print("Simulation started!")
sim = Simulation(...)
...
sim.run(100, callbacks=[hello])
External fields¶
Set static external fields by adding fields to particles’s local fields.
@callback('interpolator')
def set_static_fields(sim: Simulation):
for p in sim.patches:
for part in p.particles:
part.bz_part[:] += 10 # 10T static
part.ex_part[:] += np.sin(sim.t) # time dependent
part.ey_part[:] += np.sin(part.x/1e-6) # space dependent
Or faster with numba
@njit(parallel=True)
def set_static_fields(x, is_dead, t, ex_part):
for ipart in prange(ex_part.size):
if is_dead[ipart]:
continue
ex_part[ipart] += 10 # 10T static
ex_part[ipart] += np.sin(t) # time dependent
ex_part[ipart] += np.sin(x[ipart]/1e-6) # space dependent
@callback('interpolator')
def set_static_fields(sim: Simulation):
for p in sim.patches:
part = p.particles[ele.ispec]
set_static_fields(part.x, part.is_dead, sim.t, part.ex_part)
Reduction/Summation¶
Calculate total EM energy,
sim = Simulation(...)
ele = Electron(name='ele', ppc=10, density=...)
...
@callback('start', interval=100)
def sum_EM_enerty(sim: Simulation):
Eem = 0.0
# sum over all patches
for p in sim.patches:
f = p.fields
# NOTE: guard cells are in the [nx_per_patch:, ny_per_patch:] region
s = np.s_[:sim.nx_per_patch, :sim.ny_per_patch]
Eem += (0.5*epsilon_0*(f.ex[s]**2+f.ey[s]**2+f.ez[s]**2) +
0.5/mu_0 *(f.bx[s]**2+f.by[s]**2+f.bz[s]**2)).sum()
# sum over all mpi ranks
Eem = sim.mpi.comm.reduce(Eem)
if sim.mpi.rank > 0:
return
# print, or save to some file
print(f"{Eem=:g}")
and total electron kinetic energy.
@callback('start', interval=100)
def sum_ek(sim: Simulation):
ek = 0.0
# sum over all patches
for p in sim.patches:
part = p.particles[ele.ispec]
# select alive particles
alive = part.is_alive
ek += ((1/part.inv_gamma[alive] - 1) * ele.m/m_e * part.w[alive]).sum() # mc2
# sum over all mpi ranks
ek = sim.mpi.comm.reduce(ek)
if sim.mpi.rank > 0:
return
# print, or save to some file
print(f"{ek=:g} mc2")