try:
import ipywidgets as widgets
except ImportError:
pass
from preliz.internal.distribution_helper import get_distributions, process_extra
from preliz.internal.optimization import fit_to_quartile
from preliz.internal.plot_helper import (
check_inside_notebook,
create_figure,
representations,
reset_dist_panel,
)
class QuartileInt:
    """
    Prior elicitation for 1D distributions from quartiles (See :footcite:t:`Morris2014`).

    Parameters
    ----------
    q1 : float
        First quartile, i.e. 0.25 of the mass is below this point.
    q2 : float
        Second quartile, i.e. 0.50 of the mass is below this point. This is also known
        as the median.
    q3 : float
        Third quartile, i.e. 0.75 of the mass is below this point.
    dist_names: list
        List of distributions names to be used in the elicitation. If None, almost all 1D
        distributions available in PreliZ will be used. Some distributions like Uniform or
        Cauchy are omitted by default.
    figsize: Optional[Tuple[int, int]]
        Figure size. If None, ``(8, 6)`` is used.

    Attributes
    ----------
    dist
        The most recently fitted distribution, or None when the quartiles are not strictly
        increasing or when no distribution could be fitted.

    Notes
    -----
    Use the `params` text box to parametrize distributions, for instance write
    `BetaScaled(lower=-1, upper=10)` to specify the upper and lower bounds of BetaScaled
    distribution. To parametrize more than one distribution use commas, for example
    `StudentT(nu=3), TruncatedNormal(lower=-2, upper=inf)`.
    The `params` box does not trigger a refit by itself; its content is read on the next
    refit, e.g. after releasing the mouse over the figure or changing any other control.

    References
    ----------

    .. footbibliography::
    """

    # Every distribution offered in the selector when ``dist_names`` is None.
    _ALL_DIST_NAMES = (
        "AsymmetricLaplace",
        "BetaScaled",
        "ChiSquared",
        "ExGaussian",
        "Exponential",
        "Gamma",
        "Gumbel",
        "HalfNormal",
        "HalfStudentT",
        "InverseGamma",
        "Laplace",
        "LogNormal",
        "Logistic",
        "LogitNormal",
        "Moyal",
        "Normal",
        "Pareto",
        "Rice",
        "SkewNormal",
        "StudentT",
        "Triangular",
        "VonMises",
        "Wald",
        "Weibull",
        "BetaBinomial",
        "DiscreteWeibull",
        "Geometric",
        "NegativeBinomial",
        "Poisson",
    )
    # Subset of ``_ALL_DIST_NAMES`` pre-selected when ``dist_names`` is None.
    _DEFAULT_SELECTED = ("Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT")

    def __init__(self, q1=1, q2=2, q3=3, dist_names=None, figsize=None):
        self._q1 = q1
        self._q2 = q2
        self._q3 = q3
        self.dist = None
        self._dist_names = dist_names
        self._figsize = figsize

        check_inside_notebook(need_widget=True)

        self._widgets = self._get_widgets()
        self._output = widgets.Output()

        with self._output:
            if self._figsize is None:
                self._figsize = (8, 6)

            self._fig, self._ax_fit = create_figure(self._figsize)

            self._setup_observers()

            # The `params` text box has no observer: releasing the mouse over the figure
            # is a way to refit after editing it.
            self._fig.canvas.mpl_connect(
                "button_release_event",
                lambda event: self._match_distribution(),
            )

            # Initial fit with the starting quartiles.
            self._match_distribution()

            controls = widgets.VBox(
                [
                    self._widgets["w_q1"],
                    self._widgets["w_q2"],
                    self._widgets["w_q3"],
                    self._widgets["w_extra"],
                ]
            )

            display(  # noqa: F821
                widgets.HBox([controls, self._widgets["w_repr"], self._widgets["w_distributions"]])
            )

    def _get_widgets(self):
        """Build the control widgets and return them in a dict keyed by widget name.

        When ``dist_names`` was None, this also sets ``self._dist_names`` to the full
        default list of distributions.
        """
        width_entry_text = widgets.Layout(width="150px")
        width_repr_text = widgets.Layout(width="250px")
        width_distribution_text = widgets.Layout(width="150px", height="125px")

        def quartile_box(value, description):
            # One numeric entry per quartile; all three share the same layout.
            return widgets.FloatText(
                value=value,
                step=0.1,
                description=description,
                disabled=False,
                layout=width_entry_text,
            )

        w_q1 = quartile_box(self._q1, "q1")
        w_q2 = quartile_box(self._q2, "q2")
        w_q3 = quartile_box(self._q3, "q3")

        w_extra = widgets.Textarea(
            value="",
            placeholder="Pass extra parameters",
            description="params:",
            disabled=False,
            layout=width_repr_text,
        )

        w_repr = widgets.RadioButtons(
            options=["pdf", "cdf", "ppf"],
            value="pdf",
            description="",
            disabled=False,
            layout=width_repr_text,
        )

        if self._dist_names is None:
            self._dist_names = list(self._ALL_DIST_NAMES)
            default_dist = list(self._DEFAULT_SELECTED)
        else:
            # User-provided names are all selected initially.
            default_dist = self._dist_names

        w_distributions = widgets.SelectMultiple(
            options=self._dist_names,
            value=default_dist,
            description="",
            disabled=False,
            layout=width_distribution_text,
        )

        return {
            "w_q1": w_q1,
            "w_q2": w_q2,
            "w_q3": w_q3,
            "w_extra": w_extra,
            "w_repr": w_repr,
            "w_distributions": w_distributions,
        }

    def _match_distribution(self):
        """Fit the selected distributions to the current quartiles and redraw the panel.

        The result is stored in ``self.dist``. It is None when the quartiles are not strictly
        increasing, or when ``fit_to_quartile`` returns None (shown as "domain error").
        """
        q1 = float(self._widgets["w_q1"].value)
        q2 = float(self._widgets["w_q2"].value)
        q3 = float(self._widgets["w_q3"].value)
        extra_params = process_extra(self._widgets["w_extra"].value)

        fitted_dist = None
        if q1 < q2 < q3:
            reset_dist_panel(self._ax_fit, yticks=False)
            fitted_dist = fit_to_quartile(
                get_distributions(self._widgets["w_distributions"].value), q1, q2, q3, extra_params
            )
            if fitted_dist is None:
                self._ax_fit.set_title("domain error")
            else:
                representations(fitted_dist, self._widgets["w_repr"].value, self._ax_fit)
        else:
            reset_dist_panel(self._ax_fit, yticks=True)
            self._ax_fit.set_title("quantiles must follow the order: q1 < q2 < q3 ")

        self._fig.canvas.draw()
        self.dist = fitted_dist

    def _setup_observers(self):
        """Refit whenever a quartile, the representation or the distribution selection changes."""

        def _match_distribution_(_):
            self._match_distribution()

        # Observe only "value": without ``names`` the handler runs on every trait change
        # (selection widgets update ``index``, ``label`` and ``value`` together, plus
        # internal bookkeeping traits), causing several redundant refits per user edit.
        for key in ("w_repr", "w_distributions", "w_q1", "w_q2", "w_q3"):
            self._widgets[key].observe(_match_distribution_, names="value")