Source code for vindy.callbacks.save_coefficients_callback

import tensorflow as tf


class SaveCoefficientsCallback(tf.keras.callbacks.Callback):
    """Keras callback that records SINDy coefficients in the epoch logs.

    Every ``freq``-th epoch, the coefficients held in
    ``self.model.sindy_layer._coeffs`` are copied into the ``logs`` dict that
    Keras passes to ``on_epoch_end``, so they are saved to the training history
    alongside the usual metrics.
    """

    def __init__(self, freq=1, **kwargs):
        """
        Callback for logging SINDy coefficients during training.

        Parameters
        ----------
        freq : int, default=1
            Frequency of saving the coefficients (every freq-th epoch).
        **kwargs
            Additional keyword arguments passed to ``tf.keras.callbacks.Callback``.

        Raises
        ------
        ValueError
            If ``freq`` is smaller than 1 (it is used as a modulus in
            ``on_epoch_end``, so 0 would divide by zero there).
        """
        super().__init__(**kwargs)
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.freq = freq

    def on_epoch_end(self, epoch, logs=None):
        """Store the current SINDy coefficients in ``logs`` every ``freq`` epochs.

        Parameters
        ----------
        epoch : int
            Index of the epoch that just finished, as passed by Keras.
        logs : dict or None
            Epoch log dict provided by Keras; updated in place.
        """
        # ``epoch + 1`` so that saving happens on the freq-th, 2*freq-th, ... epoch.
        if (epoch + 1) % self.freq != 0:
            return
        # Only substitute a new dict for None: ``logs or {}`` would also replace
        # an empty (falsy) dict from Keras, and the update would be lost.
        if logs is None:
            logs = {}
        coeffs = self.model.sindy_layer._coeffs
        if isinstance(coeffs, (list, tuple)):
            # NOTE(review): assumes entries 1 and 2 are the coefficient mean and
            # scale (as the log keys suggest) — confirm against the SINDy layer.
            logs["coeffs_mean"] = coeffs[1].numpy()
            logs["coeffs_scale"] = coeffs[2].numpy()
        else:
            # Snapshot the values, consistent with the branch above; storing a
            # mutable tf.Variable itself would make every history entry alias
            # the same object and show only its latest values.
            logs["coeffs"] = coeffs.numpy() if hasattr(coeffs, "numpy") else coeffs