import tensorflow as tf
from vindy.layers.vindy_layer import VindyLayer
class PDFThresholdCallback(tf.keras.callbacks.Callback):
    """
    Keras callback that prunes VINDy coefficients via PDF thresholding.

    After every ``freq``-th epoch (and, optionally, once more when training
    ends) it asks the model's ``sindy_layer`` to cancel coefficients whose
    probability density at zero exceeds ``threshold``. Only ``VindyLayer``
    instances are supported; for any other layer type a message is printed
    and the coefficients are left untouched.
    """

    def __init__(self, freq=1, threshold=1, on_train_end=False, **kwargs):
        """
        Callback for the VINDy layer to cancel coefficients based on PDF thresholding.

        This callback sets all coefficients of the VINDy layer to zero if their
        corresponding probability density function at zero is above the threshold.

        Parameters
        ----------
        freq : int, default=1
            Frequency of coefficient cancellation (every freq-th epoch).
            Must be at least 1.
        threshold : float, default=1
            Threshold for cancellation (coefficients with pdf(0) > threshold are zeroed).
        on_train_end : bool, default=False
            Perform thresholding at end of training.
        **kwargs
            Additional keyword arguments passed to tf.keras.callbacks.Callback.

        Raises
        ------
        ValueError
            If ``freq`` is smaller than 1. A zero frequency would otherwise
            only fail later, with a ZeroDivisionError in ``on_epoch_end``.
        """
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.freq = freq
        self.threshold = threshold
        # Trailing underscore: assigning ``self.on_train_end`` would shadow the
        # ``on_train_end`` hook method below with a plain bool.
        self.on_train_end_ = on_train_end
        super().__init__(**kwargs)

    def on_epoch_end(self, epoch, logs=None):
        """
        Cancel coefficients after every ``freq``-th completed epoch.

        Parameters
        ----------
        epoch : int
            Zero-based index of the epoch that just finished.
        logs : dict, optional
            Epoch metrics supplied by Keras (unused).
        """
        # ``epoch`` is zero-based, so ``epoch + 1`` is the number of completed epochs.
        if (epoch + 1) % self.freq == 0:
            self.cancel_coefficients()

    def on_train_end(self, logs=None):
        """
        Run a final coefficient cancellation when training finishes.

        Does nothing unless the callback was created with ``on_train_end=True``.

        Parameters
        ----------
        logs : dict, optional
            Training logs supplied by Keras (unused).
        """
        if self.on_train_end_:
            self.cancel_coefficients()

    def cancel_coefficients(self):
        """
        Cancel coefficients based on their probability density at zero.

        Delegates to ``pdf_thresholding(threshold=self.threshold)`` on the
        model's ``sindy_layer``, which cancels the coefficients whose
        pdf(0) > self.threshold. If ``sindy_layer`` is not a ``VindyLayer``,
        nothing is changed and an informational message is printed instead.
        """
        # Assumes the attached model exposes a ``sindy_layer`` attribute;
        # an AttributeError is raised otherwise.
        sindy_layer = self.model.sindy_layer
        if isinstance(sindy_layer, VindyLayer):
            sindy_layer.pdf_thresholding(threshold=self.threshold)
        else:
            tf.print(
                "Canceling coefficients is only implemented for variational SINDy"
            )