"""Continuous multivariate probability distributions."""
import warnings
from copy import copy
import numpy as np
try:
    from ipywidgets import interactive, widgets
except ImportError:
    # ipywidgets is optional: in this module ``interactive`` and ``widgets`` are only
    # referenced inside the ``plot_interactive`` methods, so a missing package must
    # not prevent the module from being imported.
    pass
from scipy import stats
from preliz.distributions.beta import Beta
from preliz.distributions.distributions_multivariate import Continuous
from preliz.distributions.normal import Normal
from preliz.internal.distribution_helper import all_not_none
from preliz.internal.plot_helper import check_inside_notebook, get_slider
from preliz.internal.plot_helper_multivariate import plot_dirichlet, plot_mvnormal
# Machine epsilon of ``float``; used below as a tiny positive offset so that supports and
# parameter bounds such as (eps, 1 - eps) or (eps, inf) exclude the boundary values.
eps = np.finfo(float).eps
class Dirichlet(Continuous):
    r"""
    Dirichlet distribution.

    .. math::

       f(\mathbf{x}|\mathbf{a}) =
           \frac{\Gamma(\sum_{i=1}^k a_i)}{\prod_{i=1}^k \Gamma(a_i)}
           \prod_{i=1}^k x_i^{a_i - 1}

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        from preliz import Dirichlet
        _, axes = plt.subplots(2, 2)
        alphas = [[0.5, 0.5, 0.5], [1, 1, 1], [5, 5, 5], [5, 2, 1]]
        for alpha, ax in zip(alphas, axes.ravel()):
            Dirichlet(alpha).plot_pdf(marginals=False, ax=ax)

    ========  ===============================================
    Support   :math:`x_i \in (0, 1)` for :math:`i \in \{1, \ldots, K\}`
              such that :math:`\sum x_i = 1`
    Mean      :math:`\dfrac{a_i}{\sum a_i}`
    Variance  :math:`\dfrac{a_i (a_0 - a_i)}{a_0^2 (a_0 + 1)}`
              where :math:`a_0 = \sum a_i`
    ========  ===============================================

    Parameters
    ----------
    alpha : array of floats
        Concentration parameter (alpha > 0).
    """

    def __init__(self, alpha=None):
        super().__init__()
        # Work on a copy of the SciPy distribution object rather than the shared module-level one.
        self.dist = copy(stats.dirichlet)
        # Distribution class used for the one-dimensional marginals.
        self.marginal = Beta
        self.support = (eps, 1 - eps)
        self._parametrization(alpha)

    def _parametrization(self, alpha=None):
        """Register parameter names/supports and, if ``alpha`` is given, set it."""
        self.param_names = ("alpha",)
        self.params_support = ((eps, np.inf),)
        self.alpha = alpha
        if alpha is not None:
            self._update(alpha)

    def _get_frozen(self):
        """Return the frozen SciPy distribution, or None while parameters are missing."""
        frozen = None
        if all_not_none(self):
            frozen = self.dist(self.alpha)
        return frozen

    def _update(self, alpha):
        """Store ``alpha`` as a float array and refresh ``params`` and the frozen distribution."""
        self.alpha = np.array(alpha, dtype=float)
        self.params = (self.alpha,)
        self._update_rv_frozen()

    def _fit_mle(self, sample, **kwargs):
        """Maximum likelihood fitting is not implemented for this distribution."""
        raise NotImplementedError

    def mode(self):
        """
        Return the mode of the distribution.

        The closed form ``(alpha - 1) / (sum(alpha) - len(alpha))`` is only used when every
        element of ``alpha`` is greater than 1; otherwise ``None`` is returned.
        """
        return (
            (self.alpha - 1) / (np.sum(self.alpha) - len(self.alpha))
            if np.all(self.alpha > 1)
            else None
        )

    def plot_pdf(
        self,
        marginals=True,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        baseline=True,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the pdf of the marginals or the joint pdf of the simplex.

        The joint representation is only available for a dirichlet with an alpha of length 3.

        Parameters
        ----------
        marginals : bool
            Defaults to True, plot the marginal distributions, if False plot the joint distribution
            (only valid for an alpha of length 3).
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantile ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        baseline : bool
            Whether to include a baseline in the plot. Defaults to True. Only used when
            ``marginals=True``.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_dirichlet(
            self,
            "pdf",
            marginals,
            pointinterval,
            interval,
            levels,
            support,
            baseline,
            legend,
            figsize,
            ax,
        )

    def plot_cdf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the cumulative distribution function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantile ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        # Only marginals are plotted for this representation; no baseline is passed.
        return plot_dirichlet(
            self,
            "cdf",
            "marginals",
            pointinterval,
            interval,
            levels,
            support,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_ppf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the quantile function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantile ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        # Neither ``support`` nor ``baseline`` is passed for this representation.
        return plot_dirichlet(
            self,
            "ppf",
            "marginals",
            pointinterval,
            interval,
            levels,
            None,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_sf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the survival function (1 - CDF).

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantile ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_dirichlet(
            self,
            "sf",
            "marginals",
            pointinterval,
            interval,
            levels,
            support,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_isf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the inverse survival function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantile ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_dirichlet(
            self,
            "isf",
            "marginals",
            pointinterval,
            interval,
            levels,
            None,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_interactive(
        self,
        kind="pdf",
        xy_lim="both",
        pointinterval=True,
        interval=None,
        levels=None,
        baseline=True,
        legend="title",
        figsize=None,
    ):
        """
        Interactive exploration of parameters.

        One slider is created per element of ``alpha``; for ``pdf`` plots a checkbox
        additionally toggles between the marginals and the joint representation.

        Parameters
        ----------
        kind : str
            Type of plot. Available options are `pdf`, `cdf` and `ppf`.
        xy_lim : str or tuple
            Set the limits of the x-axis and/or y-axis.
            Defaults to `"both"`, the limits of both axes are fixed for all subplots.
            Use `"auto"` for automatic rescaling of x-axis and y-axis.
            Or set them manually by passing a tuple of 4 elements,
            the first two for x-axis, the last two for y-axis. The tuple can have `None`.
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to True.
            If `True` the default is to plot the median and two inter-quantiles ranges.
        interval : str
            Type of interval. Available options are the highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        baseline : bool
            Whether to include a baseline in the plot. Defaults to True.
            Only applicable for `pdf` plots.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure

        Returns
        -------
        ipywidgets ``interactive`` widget wrapping the plotting callback.
        """
        check_inside_notebook()

        if kind != "pdf" and baseline:
            warnings.warn("baseline is only applicable to PDF plots")

        # Re-initialize from the current parameters before building the widgets.
        args = dict(zip(self.param_names, self.params))
        self.__init__(**args)

        if kind == "pdf":
            w_checkbox_marginals = widgets.Checkbox(
                value=True,
                description="marginals",
                disabled=False,
                indent=False,
            )
            plot_widgets = {"marginals": w_checkbox_marginals}
        else:
            plot_widgets = {}

        # One slider per concentration parameter, bounded by the support of alpha.
        for index, dim in enumerate(self.params[0]):
            plot_widgets[f"alpha-{index + 1}"] = get_slider(
                f"alpha-{index + 1}", dim, *self.params_support[0]
            )

        def plot(**args):
            # The checkbox value must be removed first so that only slider values remain.
            if kind == "pdf":
                marginals = args.pop("marginals")
            params = {"alpha": np.asarray(list(args.values()), dtype=float)}
            self.__init__(**params)
            if kind == "pdf":
                plot_dirichlet(
                    self,
                    "pdf",
                    marginals,
                    pointinterval,
                    interval,
                    levels,
                    "full",
                    baseline,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )
            elif kind == "cdf":
                plot_dirichlet(
                    self,
                    "cdf",
                    "marginals",
                    pointinterval,
                    interval,
                    levels,
                    "full",
                    None,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )
            elif kind == "ppf":
                # Must request the "ppf" representation (previously "cdf" was passed here,
                # so the ppf view silently displayed the CDF).
                plot_dirichlet(
                    self,
                    "ppf",
                    "marginals",
                    pointinterval,
                    interval,
                    levels,
                    None,
                    None,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )

        return interactive(plot, **plot_widgets)
class MvNormal(Continuous):
    r"""
    Multivariate Normal distribution.

    .. math::

       f(x \mid \mu, T) =
           \frac{|T|^{1/2}}{(2\pi)^{k/2}}
           \exp\left\{ -\frac{1}{2} (x-\mu)^{\prime} T (x-\mu) \right\}

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        from preliz import MvNormal
        _, axes = plt.subplots(2, 2, figsize=(9, 9), sharex=True, sharey=True)
        mus = [[0., 0], [3, -2], [0., 0], [0., 0]]
        sigmas = [np.eye(2), np.eye(2), np.array([[2, 2], [2, 4]]), np.array([[2, -2], [-2, 4]])]
        for mu, sigma, ax in zip(mus, sigmas, axes.ravel()):
            MvNormal(mu, sigma).plot_pdf(marginals=False, ax=ax)

    ========  ==========================
    Support   :math:`x \in \mathbb{R}^k`
    Mean      :math:`\mu`
    Variance  :math:`T^{-1}`
    ========  ==========================

    MvNormal distribution has 2 alternative parameterizations. In terms of the mean and
    the covariance matrix, or in terms of the mean and the precision matrix.

    The link between the 2 alternatives is given by

    .. math::

        T = \Sigma^{-1}

    Parameters
    ----------
    mu : array of floats
        Vector of means.
    cov : array of floats, optional
        Covariance matrix.
    tau : array of floats, optional
        Precision matrix.
    """

    def __init__(self, mu=None, cov=None, tau=None):
        super().__init__()
        # Work on a copy of the SciPy distribution object rather than the shared module-level one.
        self.dist = copy(stats.multivariate_normal)
        # Distribution class used for the one-dimensional marginals.
        self.marginal = Normal
        self.support = (-np.inf, np.inf)
        self._parametrization(mu, cov, tau)

    def _parametrization(self, mu=None, cov=None, tau=None):
        """
        Select the (mu, cov) or (mu, tau) parametrization and set the parameters if given.

        Raises
        ------
        ValueError
            If both ``cov`` and ``tau`` are provided.
        """
        if all_not_none(cov, tau):
            raise ValueError("Incompatible parametrization. Either use mu and cov, or mu and tau.")

        names = ("mu", "cov")
        self.params_support = ((-np.inf, np.inf), (eps, np.inf))

        if tau is not None:
            # Internally everything is expressed through the covariance matrix.
            self.tau = tau
            cov = np.linalg.inv(tau)
            names = ("mu", "tau")

        self.mu = mu
        self.cov = cov
        self.param_names = names
        if mu is not None and cov is not None:
            self._update(mu, cov)

    def _get_frozen(self):
        """Return the frozen SciPy distribution, or None while parameters are missing."""
        frozen = None
        if all_not_none(self):
            frozen = self.dist(mean=self.mu, cov=self.cov, allow_singular=True)
        return frozen

    def _update(self, mu, cov):
        """Store ``mu``/``cov`` as float arrays, derive ``tau`` and refresh the frozen object."""
        self.mu = np.array(mu, dtype=float)
        self.cov = np.array(cov, dtype=float)
        # Invert the float-normalised matrix so that ``tau`` and ``cov`` share the same dtype.
        self.tau = np.linalg.inv(self.cov)

        # ``params`` mirrors whichever parametrization was chosen in ``_parametrization``.
        if self.param_names[1] == "cov":
            self.params = (self.mu, self.cov)
        elif self.param_names[1] == "tau":
            self.params = (self.mu, self.tau)

        self._update_rv_frozen()
        # Expose the marginal variances (diagonal of ``cov``) as ``var()`` on the frozen object.
        self.rv_frozen.var = lambda: np.diag(self.cov)

    def _fit_mle(self, sample, **kwargs):
        """Maximum likelihood fitting is not implemented for this distribution."""
        raise NotImplementedError

    def mode(self):
        """Return the mode of the distribution, which is the mean vector ``mu``."""
        return self.mu

    def plot_pdf(
        self,
        marginals=True,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        baseline=True,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the pdf of the marginals or the joint pdf.

        The joint representation is only available for a 2D Multivariate Normal.

        Parameters
        ----------
        marginals : bool
            Defaults to True, plot the marginal distributions, if False plot the joint distribution
            (only valid for a bivariate normal).
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantiles ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        baseline : bool
            Whether to include a baseline in the plot. Defaults to True. Only used when
            ``marginals=True``.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_mvnormal(
            self,
            "pdf",
            marginals,
            pointinterval,
            interval,
            levels,
            support,
            baseline,
            legend,
            figsize,
            ax,
        )

    def plot_cdf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the cumulative distribution function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantiles ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        # Only marginals are plotted for this representation; no baseline is passed.
        return plot_mvnormal(
            self,
            "cdf",
            "marginals",
            pointinterval,
            interval,
            levels,
            support,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_ppf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the quantile function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantiles ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        # Neither ``support`` nor ``baseline`` is passed for this representation.
        return plot_mvnormal(
            self,
            "ppf",
            "marginals",
            pointinterval,
            interval,
            levels,
            None,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_sf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        support="full",
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the survival function (1 - CDF).

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantiles ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        support : str
            If ``full`` use the finite end-points to set the limits of the plot. For unbounded
            end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_mvnormal(
            self,
            "sf",
            "marginals",
            pointinterval,
            interval,
            levels,
            support,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_isf(
        self,
        pointinterval=False,
        interval=None,
        levels=None,
        legend="title",
        figsize=None,
        ax=None,
    ):
        """
        Plot the inverse survival function.

        Parameters
        ----------
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to False. If True the default is to
            plot the median and two interquantiles ranges.
        interval : str
            Type of interval. Available options are highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure
        ax : matplotlib axis
            Axis to plot on

        Returns
        -------
        ax : matplotlib axis
        """
        return plot_mvnormal(
            self,
            "isf",
            "marginals",
            pointinterval,
            interval,
            levels,
            None,
            None,
            legend,
            figsize,
            ax,
        )

    def plot_interactive(
        self,
        kind="pdf",
        xy_lim="both",
        pointinterval=True,
        interval=None,
        levels=None,
        baseline=True,
        legend="title",
        figsize=None,
    ):
        """
        Interactive exploration of parameters.

        One slider is created per element of ``mu``; the covariance (or precision) matrix is
        kept fixed. For ``pdf`` plots a checkbox additionally toggles between the marginals
        and the joint representation.

        Parameters
        ----------
        kind : str
            Type of plot. Available options are `pdf`, `cdf` and `ppf`.
        xy_lim : str or tuple
            Set the limits of the x-axis and/or y-axis.
            Defaults to `"both"`, the limits of both axes are fixed for all subplots.
            Use `"auto"` for automatic rescaling of x-axis and y-axis.
            Or set them manually by passing a tuple of 4 elements,
            the first two for x-axis, the last two for y-axis. The tuple can have `None`.
        pointinterval : bool
            Whether to include a plot of the quantiles. Defaults to True.
            If `True` the default is to plot the median and two inter-quantiles ranges.
        interval : str
            Type of interval. Available options are the highest density interval `"hdi"`,
            equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
            Defaults to the value in rcParams["stats.ci_kind"].
        levels : list
            Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
            For quantiles the number of elements should be 5, 3, 1 or 0
            (in this last case nothing will be plotted).
        baseline : bool
            Whether to include a baseline in the plot. Defaults to True.
            Only applicable for `pdf` plots.
        legend : str
            Whether to include a string with the distribution and its parameter as a ``"title"``
            or not include them ``None``.
        figsize : tuple
            Size of the figure

        Returns
        -------
        ipywidgets ``interactive`` widget wrapping the plotting callback.
        """
        check_inside_notebook()

        if kind != "pdf" and baseline:
            warnings.warn("baseline is only applicable to PDF plots")

        # Remember which of cov/tau defines the current parametrization; the other one is
        # None, so re-initializing inside ``plot`` never passes both at once.
        args = dict(zip(self.param_names, self.params))
        cov, tau = args.get("cov", None), args.get("tau", None)
        self.__init__(**args)

        if kind == "pdf":
            w_checkbox_marginals = widgets.Checkbox(
                value=True,
                description="marginals",
                disabled=False,
                indent=False,
            )
            plot_widgets = {"marginals": w_checkbox_marginals}
        else:
            plot_widgets = {}

        # One slider per component of the mean vector.
        for index, mu in enumerate(self.params[0]):
            plot_widgets[f"mu-{index + 1}"] = get_slider(
                f"mu-{index + 1}", mu, *self.params_support[0]
            )

        def plot(**args):
            # The checkbox value must be removed first so that only slider values remain.
            if kind == "pdf":
                marginals = args.pop("marginals")
            params = {"mu": np.asarray(list(args.values()), dtype=float), "cov": cov, "tau": tau}
            self.__init__(**params)
            if kind == "pdf":
                plot_mvnormal(
                    self,
                    "pdf",
                    marginals,
                    pointinterval,
                    interval,
                    levels,
                    "full",
                    baseline,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )
            elif kind == "cdf":
                plot_mvnormal(
                    self,
                    "cdf",
                    "marginals",
                    pointinterval,
                    interval,
                    levels,
                    "full",
                    None,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )
            elif kind == "ppf":
                plot_mvnormal(
                    self,
                    "ppf",
                    "marginals",
                    pointinterval,
                    interval,
                    levels,
                    None,
                    None,
                    legend,
                    figsize,
                    None,
                    xy_lim,
                )

        return interactive(plot, **plot_widgets)