Source code for vindy.callbacks.save_coefficients_callback

import tensorflow as tf


class SaveCoefficientsCallback(tf.keras.callbacks.Callback):
    """Keras callback that records SINDy coefficients in the epoch logs.

    Every ``freq`` epochs, the coefficients found on
    ``self.model.sindy_layer._coeffs`` are copied into the ``logs`` dict that
    Keras passes to ``on_epoch_end`` (presumably so that Keras' ``History``
    callback records them alongside the losses — confirm callback ordering).
    """

    def __init__(self, freq=1, **kwargs):
        """
        Callback for the SINDy layer. This callback is used to log the
        coefficients during training.

        Parameters
        ----------
        freq : int
            Frequency (in epochs) of the saving of the coefficients. Must be
            a positive integer.
        kwargs : dict
            Additional keyword arguments forwarded to the Keras ``Callback``
            base class.

        Raises
        ------
        ValueError
            If ``freq`` is smaller than 1 (``freq=0`` would otherwise cause a
            ``ZeroDivisionError`` at the end of the first epoch).
        """
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.freq = freq
        super().__init__(**kwargs)

    @staticmethod
    def _snapshot(value):
        """Return a detached copy of ``value`` as a NumPy array when possible."""
        # Storing a tf.Variable itself would keep a live reference whose
        # contents keep changing during training, so every logged epoch would
        # end up showing the final values. ``.numpy()`` materializes a copy.
        return value.numpy() if hasattr(value, "numpy") else value

    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of an epoch during training.

        Writes the current SINDy coefficients into ``logs`` every ``freq``
        epochs. For a list/tuple of coefficients, elements 1 and 2 are stored
        under ``"coeffs_mean"`` and ``"coeffs_scale"``; otherwise the single
        coefficient object is stored under ``"coeffs"``.

        Parameters
        ----------
        epoch : int
            The current (0-based) epoch number.
        logs : dict, optional
            Dictionary of logs from the training process; updated in place.
        """
        # Epochs are 0-based, so this fires after epochs freq, 2*freq, ...
        if (epoch + 1) % self.freq != 0:
            return

        # Use an explicit None check: ``logs or {}`` would replace an empty
        # (but valid, caller-owned) dict with a throwaway local one, and the
        # coefficients would never reach the caller.
        if logs is None:
            logs = {}

        coeffs = self.model.sindy_layer._coeffs

        if isinstance(coeffs, (list, tuple)):
            # Element 0 is not logged here; elements 1 and 2 are logged as
            # mean and scale. NOTE(review): confirm the layout of _coeffs for
            # the variational SINDy layer.
            logs["coeffs_mean"] = self._snapshot(coeffs[1])
            logs["coeffs_scale"] = self._snapshot(coeffs[2])
        else:
            logs["coeffs"] = self._snapshot(coeffs)