Source code for preliz.multidimensional.dirichlet_mode

import warnings

import numpy as np

from preliz.distributions import Beta, Dirichlet
from preliz.internal.optimization import optimize_dirichlet_mode
from preliz.internal.rcparams import rcParams


[docs] def dirichlet_mode(mode, mass=None, bound=0.01, plot=None, plot_kwargs=None, ax=None): """ Elicitate a Dirichlet distribution with a given mode and mass. Computes a Dirichlet distribution where the marginals have the specified mode and mass and their masses lie within the range mode ± bound (Adapted from :footcite:t:`Michael2017`). Parameters ---------- mode : list Mode of the Dirichlet distribution. mass : float Probability mass between within mode +- bounds. Defaults to None, which results in the value of rcParams["stats.ci_prob"] being used. bound : float Defines upper and lower bounds for the mass as mode +- bound. Defaults to 0.01. plot : bool Whether to plot the distribution. Defaults to True. plot_kwargs : dict Dictionary passed to the method ``plot_pdf()``. ax : matplotlib axes Returns ------- ax : matplotlib axes dist : Preliz Dirichlet distribution. Dirichlet distribution with fitted parameters alpha for the given mass and intervals. References ---------- .. footbibliography:: """ if mass is None: mass = rcParams["stats.ci_prob"] if plot is None: plot = rcParams["plots.show_plot"] if not 0 < mass <= 1: raise ValueError("mass should be larger than 0 and smaller or equal to 1") if not all(i > 0 for i in mode): raise ValueError("mode should be larger than 0") if not abs(sum(mode) - 1) < 0.0001: warnings.warn("The mode should sum to 1, normalising mode to sum to 1") sum_mode = sum(mode) mode = [i / sum_mode for i in mode] if plot_kwargs is None: plot_kwargs = {} lower_bounds = np.clip(np.array(mode) - bound, 0, 1) target_mass = (1 - mass) / 2 _dist = Beta() _, alpha = optimize_dirichlet_mode(lower_bounds, mode, target_mass, _dist) alpha_np = np.array(alpha) calculated_mode = (alpha_np - 1) / (alpha_np.sum() - len(alpha_np)) if np.any((np.array(mode) - calculated_mode) > 0.01): warnings.warn( f"The requested mode {mode} is different from the calculated mode {calculated_mode}." ) dirichlet_distribution = Dirichlet(alpha) if plot: ax = dirichlet_distribution.plot_pdf(**plot_kwargs, ax=ax) return ax, dirichlet_distribution