Source code for vindy.distributions.base_distribution

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from abc import abstractmethod, ABC

class BaseDistribution(tf.keras.layers.Layer, ABC):
    """
    Base class for probabilistic distributions used in variational encoders.

    Subclasses should implement sampling and log-probability computations.

    Methods
    -------
    call(inputs)
        Return samples and any auxiliary outputs (e.g. mean/logvar).
    """

    @abstractmethod
    def call(self, inputs):
        """
        Sample from the distribution.

        Parameters
        ----------
        inputs : array-like
            Inputs used to parameterize the distribution (for instance
            mean/logvar).

        Returns
        -------
        tuple or tf.Tensor
            Samples (and optionally auxiliary statistics).
        """
        pass

    @abstractmethod
    def KL_divergence(self):
        """
        Compute the KL divergence between two distributions.

        Returns
        -------
        tf.Tensor
            Scalar KL divergence.
        """
        pass

    @abstractmethod
    def prob_density_fcn(self, x, mean, scale):
        """
        Probability density function.

        Parameters
        ----------
        x : array-like
            Points at which to evaluate the density.
        mean : float or array-like
            Distribution mean/loc parameter.
        scale : float or array-like
            Scale parameter (std, scale, etc.).

        Returns
        -------
        array-like
            Density values at x.
        """
        pass

    @abstractmethod
    def variance(self, scale):
        """
        Variance as a function of the distribution scale parameter.

        Parameters
        ----------
        scale : float or array-like
            Scale parameter of the distribution.

        Returns
        -------
        float or array-like
            Variance corresponding to the provided scale.
        """
        pass

    def reverse_log(self, log_scale):
        """
        Convert a log scale to a scale: s = exp(log(s)) = exp(log_scale).

        Parameters
        ----------
        log_scale : array-like
            Logarithm of the scale parameter.

        Returns
        -------
        array-like
            Scale (exp(log_scale)).
        """
        return tf.exp(log_scale)

    def plot(self, mean, scale, ax=None, num_points=3000):
        """
        Plot the probability density function of the distribution.

        The density is evaluated on an evenly spaced grid running from
        ``mean - variance`` to ``mean + variance`` (variance as returned by
        ``self.variance(scale)``). If zero lies strictly inside that grid it
        is added as an extra evaluation point, so that any feature located
        exactly at zero is drawn rather than interpolated over.

        Parameters
        ----------
        mean : float or array-like
            Mean/loc of the distribution.
        scale : float or array-like
            Scale parameter.
        ax : matplotlib.axes.Axes, optional
            Axis to draw on. If None, uses the current axis.
        num_points : int, optional
            Number of evenly spaced grid points (default 3000).

        Returns
        -------
        matplotlib.axes.Axes
            The axis the density was drawn on.
        """
        if ax is None:
            ax = plt.gca()

        # NOTE(review): the half-width of the plotting window is the variance,
        # not the standard deviation; kept as-is for backward compatibility.
        half_width = self.variance(scale)
        x = np.linspace(-half_width, half_width, num_points) + mean

        # Normalize to a NumPy array before any index arithmetic; adding a
        # tf.Tensor mean can turn the grid into a tensor.
        if isinstance(x, tf.Tensor):
            x = x.numpy()
        x = np.asarray(x).squeeze()

        # Add x = 0 only when it is strictly inside the (ascending) grid.
        # Inserting it when the whole grid is positive would prepend 0 and
        # stretch the plotted domain down to the origin.
        if x.ndim == 1 and x[0] < 0 < x[-1]:
            x = np.insert(x, np.searchsorted(x, 0), 0)

        # Evaluate the density once and reuse it for the line and the fill.
        density = self.prob_density_fcn(x, mean, scale)
        ax.plot(x, density)
        ax.fill_between(x, density, alpha=0.3)
        return ax