vindy.utils package
Submodules
vindy.utils.utils module
- add_lognormal_noise(trajectory, sigma)
Add lognormal noise to a trajectory.
- Parameters:
trajectory (array-like) – The trajectory data to add noise to.
sigma (float) – Standard deviation parameter for the lognormal distribution.
- Returns:
Tuple containing (noisy_trajectory, noise).
- Return type:
tuple of array-like
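The docstring does not specify how the lognormal noise is combined with the data. The sketch below assumes one common reading: multiplicative noise, where each entry is scaled by an independent factor drawn from a lognormal distribution with log-mean 0 and log-standard-deviation `sigma`. The helper `add_lognormal_noise_sketch` and its `rng` argument are hypothetical, and vindy's implementation may differ.

```python
import numpy as np


def add_lognormal_noise_sketch(trajectory, sigma, rng=None):
    """Hypothetical sketch: multiplicative lognormal noise (not the vindy source)."""
    rng = np.random.default_rng() if rng is None else rng
    trajectory = np.asarray(trajectory, dtype=float)
    # One positive factor per entry; exp(N(0, sigma^2)) has median 1.
    noise = rng.lognormal(mean=0.0, sigma=sigma, size=trajectory.shape)
    return trajectory * noise, noise


# Usage: perturb a short trajectory with a fixed generator for reproducibility.
x = np.linspace(1.0, 2.0, 5)
noisy, noise = add_lognormal_noise_sketch(x, sigma=0.1, rng=np.random.default_rng(0))
# With sigma=0 every factor is exp(0) = 1, so the data is unchanged.
unchanged, unit_noise = add_lognormal_noise_sketch(x, sigma=0.0)
```

Under this assumption the noise is always positive and scales with the signal magnitude, which is why a lognormal model is often chosen for strictly positive quantities.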
- coefficient_distribution_gif(mean_over_epochs, scale_over_epochs, sindy_layer, outdir)
Create a GIF showing how coefficient distributions evolve over epochs.
- Parameters:
mean_over_epochs (list of array-like) – Mean values of coefficients at each epoch.
scale_over_epochs (list of array-like) – Scale values of coefficients at each epoch.
sindy_layer (SindyLayer) – SINDy layer instance with visualization methods.
outdir (str) – Output directory for saving frames and the GIF.
- coefficient_distributions_to_csv(sindy_layer, outdir, var_names=[], param_names=[])
Save the coefficient distributions of the SINDy layer to CSV files.
- Parameters:
sindy_layer (SindyLayer) – SINDy layer containing coefficient distributions.
outdir (str) – Output directory for CSV files.
var_names (list of str, optional) – Names of state variables (default is []).
param_names (list of str, optional) – Names of parameters (default is []).
- create_result_directory(base_dir: str, model_name: str) → str
Create a result directory for saving model outputs.
- Parameters:
base_dir (str) – Base directory for results.
model_name (str) – Name of the model/experiment.
- Returns:
Path to the created result directory.
- Return type:
str
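A minimal sketch of what such a helper typically does, assuming the result directory is simply `<base_dir>/<model_name>` and that an existing directory is reused. The real naming scheme (for example, whether a timestamp is appended) is not documented here, and `create_result_directory_sketch` is a hypothetical name.

```python
import os
import tempfile


def create_result_directory_sketch(base_dir: str, model_name: str) -> str:
    """Hypothetical layout <base_dir>/<model_name>; not the vindy source."""
    result_dir = os.path.join(base_dir, model_name)
    # exist_ok=True makes repeated runs reuse the same directory instead of failing.
    os.makedirs(result_dir, exist_ok=True)
    return result_dir


# Usage inside a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    path = create_result_directory_sketch(tmp, "veni_demo")
    created = os.path.isdir(path)
    # A second call with the same arguments succeeds and returns the same path.
    same_path = create_result_directory_sketch(tmp, "veni_demo") == path
```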
- get_config()
Import and return the config module.
This function imports the examples' config module with fallbacks, so the examples can be run both when examples is a package and when it is a plain directory with config.py next to the example scripts.
- Returns:
The config module object.
- Return type:
module
- Raises:
ImportError – If config.py doesn’t exist or can’t be imported.
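The two-step fallback described above can be sketched with `importlib`. This is an illustration under stated assumptions, not vindy's code: the helper `load_config_sketch`, its explicit `script_dir` argument (the real `get_config()` takes no arguments and presumably locates the directory itself), and the package path `examples.config` are all hypothetical.

```python
import importlib
import importlib.util
import os
import tempfile


def load_config_sketch(script_dir: str, package_module: str = "examples.config"):
    """Hypothetical fallback loader for a config module."""
    try:
        # Case 1: `examples` is an importable package that contains config.py.
        return importlib.import_module(package_module)
    except ImportError:
        pass
    # Case 2: config.py is a plain file next to the example scripts.
    config_path = os.path.join(script_dir, "config.py")
    if not os.path.isfile(config_path):
        raise ImportError(f"config.py not found in {script_dir}")
    spec = importlib.util.spec_from_file_location("config", config_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # execute the file to populate the module
    return module


# Usage: write a tiny config.py into a temporary directory and load it.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.py"), "w") as f:
        f.write("DATA_DIR = 'data'\n")
    config = load_config_sketch(tmp)
    data_dir = config.DATA_DIR
```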
- get_latent_initial_conditions(veni, x, dxdt, dxddt, mean_or_sample)
Compute the initial conditions in latent space for integration.
The initial conditions are computed from the provided state and its time derivatives.
- Parameters:
veni (VENI) – The VENI model instance.
x (array-like) – The state data array.
dxdt (array-like) – The first time derivative of the state data array.
dxddt (array-like or None) – The second time derivative of the state data array (can be None if not available).
mean_or_sample (str) – Whether to compute the mean initial condition or sample from the distribution (“mean” or “sample”).
- Returns:
Initial conditions in latent space.
- Return type:
array-like
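The `mean_or_sample` switch is easiest to picture for a Gaussian variational encoder that outputs a mean and a log-variance per latent variable. The sketch below assumes exactly that and shows only the choice between the two modes; it does not reproduce how vindy maps `dxdt` and `dxddt` into latent space. The helper `select_initial_condition_sketch` and its inputs `z_mean` and `z_log_var` are hypothetical.

```python
import numpy as np


def select_initial_condition_sketch(z_mean, z_log_var, mean_or_sample, rng=None):
    """Hypothetical: pick the latent mean or draw a reparameterized sample."""
    z_mean = np.asarray(z_mean, dtype=float)
    if mean_or_sample == "mean":
        return z_mean
    if mean_or_sample == "sample":
        rng = np.random.default_rng() if rng is None else rng
        eps = rng.standard_normal(z_mean.shape)
        # Reparameterization: z = mu + sigma * eps, with sigma = exp(0.5 * log_var).
        return z_mean + np.exp(0.5 * np.asarray(z_log_var, dtype=float)) * eps
    raise ValueError("mean_or_sample must be 'mean' or 'sample'")


# Usage with a two-dimensional latent state.
z_mean = np.array([0.5, -1.0])
z_log_var = np.log(np.array([0.01, 0.04]))
z0_mean = select_initial_condition_sketch(z_mean, z_log_var, "mean")
z0_sample = select_initial_condition_sketch(
    z_mean, z_log_var, "sample", rng=np.random.default_rng(0)
)
```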
- log_model_summary(veni, result_dir: str = None)
Log a summary of the VENI model architecture.
- Parameters:
veni (VENI) – The VENI model instance.
result_dir (str, optional) – Directory to save the summary text file (default is None).
- perform_forward_uq(veni, sim_ids, n_traj, n_sims, n_timesteps, t, x, dxdt, dxddt=None, params=None, sigma=3)
Perform forward uncertainty quantification by sampling trajectories.
Samples trajectories from the SINDy model. The arguments dxddt and params are optional; pass None if they are not available.
- Parameters:
veni (VENI) – The VENI model instance.
sim_ids (list of int) – List of simulation indices to process.
n_traj (int) – Number of trajectories to sample for each simulation.
n_sims (int) – Total number of simulations in the dataset.
n_timesteps (int) – Number of timesteps in each simulation.
t (array-like) – Time vector.
x (array-like) – State data.
dxdt (array-like) – First time derivative of state data.
dxddt (array-like, optional) – Second time derivative of state data (default is None).
params (array-like, optional) – Additional parameters for integration (default is None).
sigma (float, optional) – Number of standard deviations for confidence intervals (default is 3).
- Returns:
Dictionary with keys: sampled_times, latent_trajectories_samples, mean_latent_samples, std_latent_samples, mean_latent, lower_bound_latent, upper_bound_latent, z, dzdt.
- Return type:
dict
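The `sigma` parameter sets the width of the confidence band in standard deviations. A minimal sketch of how sampled trajectories can be reduced to mean, standard deviation, and `mean ± sigma * std` bounds, assuming samples are stacked as `(n_traj, n_timesteps, n_states)`. Whether vindy centers the band on the sample mean or on the deterministic mean trajectory (`mean_latent`) is not stated here; `confidence_band_sketch` is a hypothetical helper.

```python
import numpy as np


def confidence_band_sketch(samples, sigma=3):
    """Hypothetical reduction of sampled trajectories to a sigma-band."""
    samples = np.asarray(samples, dtype=float)  # (n_traj, n_timesteps, n_states)
    mean = samples.mean(axis=0)                 # average over trajectories
    std = samples.std(axis=0)                   # population standard deviation
    return mean, std, mean - sigma * std, mean + sigma * std


# Usage: two sampled trajectories, two timesteps, one state variable.
samples = np.array([[[0.0], [1.0]], [[2.0], [3.0]]])
mean, std, lower, upper = confidence_band_sketch(samples, sigma=3)
# mean -> [[1.], [2.]], std -> [[1.], [1.]]
```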
- perform_inference(veni, sim_ids, n_sims, n_timesteps, t, x, dxdt=None, params=None)
Perform inference on test trajectories and plot the results.
- Parameters:
veni (VENI) – The trained VENI model.
sim_ids (list of int) – List of test trajectory indices.
n_sims (int) – Number of simulations.
n_timesteps (int) – Number of timesteps in each test trajectory.
t (array-like) – Test time steps.
x (array-like) – Scaled test data.
dxdt (array-like, optional) – Scaled test data derivatives (default is None).
params (array-like, optional) – Test parameters (default is None).
- Returns:
Tuple (t_preds, z_preds) containing the predicted time steps and the corresponding predicted trajectories.
- Return type:
tuple
- plot_coefficients_train_history(trainhist, result_dir)
Plot the evolution of SINDy coefficients during training.
- Parameters:
trainhist (dict) – Training history dictionary.
result_dir (str) – Directory to save the plot.
- plot_inference_results(t_preds, z_preds, T, Z, sim_ids, state_id=0)
Plot inference results for test trajectories.
- Parameters:
t_preds (list of array-like) – Predicted time steps for each trajectory.
z_preds (list of array-like) – Predicted latent states for each trajectory.
T (array-like) – True time steps.
Z (array-like) – True latent states.
sim_ids (list of int) – List of simulation indices to plot.
state_id (int, optional) – Index of the state variable to plot (default is 0).
- plot_train_history(trainhist, result_dir, validation=True)
Plot training history including loss curves.
- Parameters:
trainhist (dict) – Training history dictionary containing loss values.
result_dir (str) – Directory to save the plot.
validation (bool, optional) – Whether to include validation loss in the plot (default is True).
- set_seed(seed: int)
Set seed for reproducibility in TensorFlow, NumPy, and Python’s random module.
- Parameters:
seed (int) – The seed value to set.
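Seeding the three random number generators usually follows the pattern below. This is a sketch of the common idiom, not vindy's source: `set_seed_sketch` is a hypothetical name, and TensorFlow is seeded only if it is installed, so the example also runs without it.

```python
import random

import numpy as np


def set_seed_sketch(seed: int) -> None:
    """Hypothetical seeding helper for TensorFlow, NumPy, and Python's random."""
    try:
        import tensorflow as tf  # optional dependency in this sketch
        tf.random.set_seed(seed)
    except ImportError:
        pass  # TensorFlow not installed: seed only Python and NumPy
    random.seed(seed)
    np.random.seed(seed)


# Usage: re-seeding reproduces the same draws.
set_seed_sketch(42)
first = (random.random(), np.random.rand())
set_seed_sketch(42)
second = (random.random(), np.random.rand())
```

Full determinism of GPU training can require additional framework settings beyond seeding.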
- switch_data_format(data, n_sims, n_timesteps, spatial_shape=None, target_format='auto')
Convert between vectorized (2D), simulation-wise flattened (3D), and full spatial (5D) data formats.
- Parameters:
data (np.ndarray) – Input data in one of the following layouts:
2D: (n_sims * n_timesteps, features)
3D: (n_sims, n_timesteps, features)
5D: (n_sims, n_timesteps, Nx, Ny, channels)
n_sims (int) – Number of simulations.
n_timesteps (int) – Number of timesteps per simulation.
spatial_shape (tuple, optional) – Spatial dimensions, given as (Nx, Ny, channels), (N, channels), or (Nx, Ny). Required when converting to or from 5D unless it can be inferred unambiguously from the feature size (default is None).
target_format (str, optional) – 'auto', '2d', '3d', or '5d'. With 'auto', the function chooses a sensible target based on the input (default is 'auto').
- Returns:
The converted array in the requested format.
- Return type:
np.ndarray
- Examples:
2D -> 5D: provide spatial_shape=(Nx, Ny, channels) and target_format='5d'.
5D -> 2D: pass target_format='2d'; with 'auto', a 5D input is converted to the 3D (simulation-wise flattened) format by default.
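The conversions between the three layouts reduce to NumPy reshapes, provided the 2D rows are ordered simulation by simulation. The sketch below assumes that ordering and supports only an explicit target with spatial_shape=(Nx, Ny, channels); it omits the 'auto' mode and the other spatial_shape forms. `switch_format_sketch` is a hypothetical helper, not the vindy function.

```python
import numpy as np


def switch_format_sketch(data, n_sims, n_timesteps, target, spatial_shape=None):
    """Hypothetical reshape between 2D, 3D, and 5D layouts (no 'auto' mode)."""
    data = np.asarray(data)
    # Go through the 3D layout (n_sims, n_timesteps, features); assumes
    # row-major ordering with all timesteps of one simulation stored together.
    flat = data.reshape(n_sims, n_timesteps, -1)
    if target == "3d":
        return flat
    if target == "2d":
        return flat.reshape(n_sims * n_timesteps, -1)
    if target == "5d":
        if spatial_shape is None or len(spatial_shape) != 3:
            raise ValueError("spatial_shape=(Nx, Ny, channels) is required for 5D output")
        return flat.reshape(n_sims, n_timesteps, *spatial_shape)
    raise ValueError(f"unknown target format: {target!r}")


# Usage: 2 simulations, 3 timesteps, a 4 x 5 grid with 1 channel.
data_5d = np.arange(2 * 3 * 4 * 5 * 1).reshape(2, 3, 4, 5, 1)
data_2d = switch_format_sketch(data_5d, 2, 3, "2d")  # shape (6, 20)
data_3d = switch_format_sketch(data_2d, 2, 3, "3d")  # shape (2, 3, 20)
back_5d = switch_format_sketch(data_2d, 2, 3, "5d", spatial_shape=(4, 5, 1))
```

Because every conversion is a reshape of the same contiguous buffer, a round trip returns the original array exactly.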
- uq_plots(sampled_times, mean_latent, mean_latent_samples, std_latent_samples, t_test, z_test, test_ids, state_id=0)
Generate uncertainty quantification plots.
- Parameters:
sampled_times (list of array-like) – Time points for sampled UQ trajectories.
mean_latent (array-like) – Mean trajectories from deterministic integration.
mean_latent_samples (array-like) – Mean of sampled trajectories.
std_latent_samples (array-like) – Standard deviation of sampled trajectories.
t_test (array-like) – Test time steps.
z_test (array-like) – Latent states for test data.
test_ids (list of int) – List of test trajectory indices to plot.
state_id (int, optional) – Index of the state variable to plot (default is 0).
- validate_data_path(data_path: str, zenodo_doi: str = '10.5281/zenodo.18313843')
Validate that a data file exists and provide a helpful error message if not.
- Parameters:
data_path (str) – Path to the data file.
zenodo_doi (str, optional) – Zenodo DOI for downloading the data (default is “10.5281/zenodo.18313843”).
- Raises:
FileNotFoundError – If the data file does not exist.
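A check like this typically tests for the file and, if it is missing, raises FileNotFoundError with a pointer to the dataset. The sketch below builds the link with the standard https://doi.org/ resolver; the helper name `validate_data_path_sketch` and the exact message wording are hypothetical.

```python
import os


def validate_data_path_sketch(data_path: str, zenodo_doi: str = "10.5281/zenodo.18313843") -> None:
    """Hypothetical existence check with a download hint (message text is illustrative)."""
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"Data file not found: {data_path}\n"
            f"Download the dataset from https://doi.org/{zenodo_doi} "
            f"and place it at the path above."
        )


# Usage: a missing file produces an error message that includes the DOI link.
try:
    validate_data_path_sketch("definitely_missing_file.npz")
    message = ""
except FileNotFoundError as err:
    message = str(err)
```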