from math import ceil, floor
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
try:
import ipywidgets as widgets
except ImportError:
pass
from preliz.distributions import all_continuous, all_discrete
from preliz.internal.distribution_helper import get_distributions, process_extra
from preliz.internal.optimization import fit_to_epdf
from preliz.internal.plot_helper import check_inside_notebook, representations
class Roulette:
    """
    Prior elicitation for 1D distribution using the roulette method (See :footcite:t:`Morris2014`).

    Draw 1D distributions using a grid as input.

    Parameters
    ----------
    x_min: Optional[float]
        Minimum value for the domain of the grid and fitted distribution. Defaults to 0.
    x_max: Optional[float]
        Maximum value for the domain of the grid and fitted distribution. Defaults to 10.
    nrows: Optional[int]
        Number of rows for the grid. Defaults to 10.
    ncols: Optional[int]
        Number of columns for the grid. Defaults to 11.
    dist_names: list
        List of distribution names to be used in the elicitation.
        Defaults to None. The pre-selected distributions are ["Normal", "BetaScaled",
        "Gamma", "LogNormal", "StudentT"], but almost all 1D PreliZ's distributions
        are available to be selected from the menu with some exceptions like Uniform
        or Cauchy.
    params: Optional[str]
        Extra parameters to be passed to the distributions. The format is a string with the
        PreliZ's distribution name followed by the argument to fix.
        For example: "TruncatedNormal(lower=0), StudentT(nu=8)". If you use the ``params``
        text area, quotation marks are not necessary.
    figsize: Optional[Tuple[int, int]]
        Figure size. If None, ``(8, 6)`` is used.

    Returns
    -------
    Roulette object
        The object has many attributes, but the most important are:

        - dist: The fitted distribution.
        - inputs: A tuple with the x values, the empirical pdf, the total
          chips, the x_min, the x_max, the number of rows, and the number of columns.
    """

    def __init__(
        self, x_min=0, x_max=10, nrows=10, ncols=11, dist_names=None, params=None, figsize=None
    ):
        self._x_min = x_min
        self._x_max = x_max
        self._nrows = nrows
        self._ncols = ncols
        self._dist_names = dist_names
        self._figsize = figsize
        self._w_extra = params
        # Public results; both are refreshed by ``_on_leave_fig``.
        self.dist = None
        self.inputs = None

        # Presumably aborts when not running in a notebook or when ipywidgets is
        # missing (the import at the top of the file is allowed to fail silently).
        check_inside_notebook(need_widget=True)

        self._widgets = self._get_widgets()
        self._output = widgets.Output()

        # NOTE(review): the original indentation was lost; the extent of this
        # ``with`` block is reconstructed. Confirm that everything down to
        # ``display`` is meant to run inside the Output context.
        with self._output:
            if self._figsize is None:
                self._figsize = (8, 6)

            self._fig, self._ax_grid, self._ax_fit = self._create_figure()
            self._coll = self._create_grid()
            self._grid = _Rectangles(self._fig, self._coll, self._nrows, self._ncols, self._ax_grid)
            self._setup_observers()
            # Re-fit every time the mouse button is released over the figure.
            self._fig.canvas.mpl_connect("button_release_event", lambda event: self._on_leave_fig())

            controls = widgets.VBox(
                [
                    self._widgets["w_x_min"],
                    self._widgets["w_x_max"],
                    self._widgets["w_nrows"],
                    self._widgets["w_ncols"],
                    self._widgets["w_extra"],
                ]
            )
            control_distribution = widgets.VBox(
                [
                    self._widgets["w_checkbox_cont"],
                    self._widgets["w_checkbox_disc"],
                    self._widgets["w_checkbox_none"],
                ]
            )

            # ``display`` is the builtin injected by IPython, hence the noqa.
            display(  # noqa: F821
                widgets.HBox(
                    [
                        controls,
                        self._widgets["w_repr"],
                        self._widgets["w_distributions"],
                        control_distribution,
                    ]
                )
            )

    def _create_figure(self):
        """Create the two-panel figure and return ``(fig, ax_grid, ax_fit)``.

        The top axes hosts the clickable grid, the bottom one the fitted
        distribution.
        """
        fig, (ax_grid, ax_fit) = plt.subplots(
            2, 1, figsize=self._figsize, constrained_layout=True
        )
        ax_fit.set_yticks([])
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False
        fig.canvas.toolbar_position = "right"
        return fig, ax_grid, ax_fit

    def _create_grid(self):
        """Draw the ``nrows x ncols`` grid of unit squares on the grid axes.

        Returns
        -------
        numpy.ndarray
            Object array of shape ``(nrows, ncols)`` holding the rectangle
            patches, indexed as ``[row, column]``.
        """
        # Use at most 11 tick labels, spread evenly over the column centres.
        num = min(self._ncols, 11)
        self._ax_grid.set(
            xticks=np.linspace(0, self._ncols - 1, num=num) + 0.5,
            xticklabels=[f"{i:.1f}" for i in np.linspace(self._x_min, self._x_max, num=num)],
        )

        coll = np.zeros((self._nrows, self._ncols), dtype=object)
        for idx in range(self._ncols):
            for idy in range(self._nrows):
                sq = patches.Rectangle(
                    (idx, idy), 1, 1, fill=True, facecolor="0.8", edgecolor="w"
                )
                self._ax_grid.add_patch(sq)
                coll[idy, idx] = sq

        self._ax_grid.set_yticks([])
        self._ax_grid.relim()
        self._ax_grid.autoscale_view()
        return coll

    def _on_leave_fig(self):
        """Fit the selected distributions to the current grid and redraw.

        Updates ``self.inputs`` and ``self.dist``; ``self.dist`` is ``None``
        when fewer than two columns hold chips, when no distribution is
        selected, or when the fit returns ``None``.
        """
        extra_pros = process_extra(self._widgets["w_extra"].value)
        x_vals, epdf, mean, std, filled_columns = self._weights_to_pdf()

        fitted_dist = None
        # NOTE(review): nesting reconstructed from the flattened source; the
        # ``else`` is assumed to pair with the ``filled_columns`` test.
        if filled_columns > 1:
            selected_distributions = get_distributions(self._widgets["w_distributions"].value)

            if selected_distributions:
                self._reset_dist_panel(yticks=False)
                fitted_dist = fit_to_epdf(
                    selected_distributions,
                    x_vals,
                    epdf,
                    mean,
                    std,
                    self._x_min,
                    self._x_max,
                    extra_pros,
                )

                if fitted_dist is None:
                    self._ax_fit.set_title("domain error")
                else:
                    representations(fitted_dist, self._widgets["w_repr"].value, self._ax_fit)
        else:
            self._reset_dist_panel(yticks=True)
        self._fig.canvas.draw()

        self.inputs = (
            x_vals,
            epdf,
            sum(self._grid._weights.values()),
            self._x_min,
            self._x_max,
            self._nrows,
            self._ncols,
        )
        self.dist = fitted_dist

    def _weights_to_pdf(self):
        """Turn the chips per column into an empirical pdf.

        Returns
        -------
        tuple
            ``(x_vals, epdf, mean, std, filled_columns)`` where ``x_vals`` and
            ``epdf`` only cover the columns with a non-zero chip count, ``mean``
            and ``std`` are the moments of that empirical pdf and
            ``filled_columns`` is ``len(x_vals)``.
        """
        weights = self._grid._weights
        # A one-column grid would otherwise divide by zero; treat it as a single
        # step spanning the whole domain.
        step = (self._x_max - self._x_min) / max(self._ncols - 1, 1)
        # NOTE(review): columns are mapped half a step to the right of the tick
        # labels drawn by ``_create_grid`` -- kept as is, confirm it is intended.
        x_vals = [(k + 0.5) * step + self._x_min for k, v in weights.items() if v != 0]
        total = sum(weights.values())
        # With no chips at all the list is empty, so ``total == 0`` never divides.
        epdf = [v / total for v in weights.values() if v != 0]
        mean = sum(prob * value for value, prob in zip(x_vals, epdf))
        std = (sum(prob * (value - mean) ** 2 for value, prob in zip(x_vals, epdf))) ** 0.5
        return x_vals, epdf, mean, std, len(x_vals)

    def _update_grid(self):
        """Rebuild the grid from the current settings and discard all chips."""
        self._ax_grid.cla()
        self._coll = self._create_grid()
        # Keep the click handler in sync with the new grid.
        self._grid._coll = self._coll
        self._grid._ncols = self._ncols
        self._grid._nrows = self._nrows
        self._grid._weights = {k: 0 for k in range(0, self._ncols)}
        self._reset_dist_panel(yticks=True)
        self._ax_grid.set_yticks([])
        self._ax_grid.relim()
        self._ax_grid.autoscale_view()
        self._fig.canvas.draw()

    def _reset_dist_panel(self, yticks):
        """Clear the fit axes and restore its x-limits.

        Parameters
        ----------
        yticks : bool
            When True the y ticks are removed (used for the empty panel).
        """
        self._ax_fit.cla()
        if yticks:
            self._ax_fit.set_yticks([])
        self._ax_fit.set_xlim(self._x_min, self._x_max)
        self._ax_fit.relim()
        self._ax_fit.autoscale_view()

    def _handle_checkbox_widget(self):
        """Return the distribution names implied by the three checkboxes.

        "None" wins over the other two: it unticks them and yields an empty
        list. Otherwise the continuous and/or discrete distributions that are
        present in the selection menu are returned.
        """
        if self._widgets["w_checkbox_none"].value:
            self._widgets["w_checkbox_disc"].value = False
            self._widgets["w_checkbox_cont"].value = False
            return []

        available = self._widgets["w_distributions"].options
        all_cls = []
        if self._widgets["w_checkbox_cont"].value:
            all_cls += [cls.__name__ for cls in all_continuous if cls.__name__ in available]
        if self._widgets["w_checkbox_disc"].value:
            all_cls += [cls.__name__ for cls in all_discrete if cls.__name__ in available]
        return all_cls

    def _get_widgets(self):
        """Build every control and return them in a dict keyed by widget name."""
        width_entry_text = widgets.Layout(width="150px")
        width_repr_text = widgets.Layout(width="250px")
        width_distribution_text = widgets.Layout(width="150px", height="125px")

        w_x_min = widgets.FloatText(
            value=self._x_min,
            step=1,
            description="x_min:",
            disabled=False,
            layout=width_entry_text,
        )
        w_x_max = widgets.FloatText(
            value=self._x_max,
            step=1,
            description="x_max:",
            disabled=False,
            layout=width_entry_text,
        )
        w_nrows = widgets.BoundedIntText(
            value=self._nrows,
            min=2,
            step=1,
            description="n_rows:",
            disabled=False,
            layout=width_entry_text,
        )
        w_ncols = widgets.BoundedIntText(
            value=self._ncols,
            min=2,
            step=1,
            description="n_cols:",
            disabled=False,
            layout=width_entry_text,
        )
        w_extra = widgets.Textarea(
            value=self._w_extra,
            placeholder="Pass extra parameters",
            description="params:",
            disabled=False,
            layout=width_repr_text,
        )
        w_repr = widgets.RadioButtons(
            options=["pdf", "cdf", "ppf"],
            value="pdf",
            description="",
            disabled=False,
            layout=width_entry_text,
        )

        if self._dist_names is None:
            default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]
            dist_names = [
                "AsymmetricLaplace",
                "BetaScaled",
                "ChiSquared",
                "ExGaussian",
                "Exponential",
                "Gamma",
                "Gumbel",
                "HalfNormal",
                "HalfStudentT",
                "InverseGamma",
                "Laplace",
                "LogNormal",
                "Logistic",
                # "LogitNormal", # fails if we add chips at x_value= 1
                "Moyal",
                "Normal",
                "Pareto",
                "Rice",
                "SkewNormal",
                "StudentT",
                "Triangular",
                "VonMises",
                "Wald",
                "Weibull",
                "BetaBinomial",
                "DiscreteWeibull",
                "Geometric",
                "NegativeBinomial",
                "Poisson",
            ]
        else:
            # A user-supplied list is both the menu and the initial selection.
            default_dist = self._dist_names
            dist_names = self._dist_names

        w_distributions = widgets.SelectMultiple(
            options=dist_names,
            value=default_dist,
            description="",
            disabled=False,
            layout=width_distribution_text,
        )
        w_checkbox_cont = widgets.Checkbox(
            value=False, description="Continuous", disabled=False, indent=False
        )
        w_checkbox_disc = widgets.Checkbox(
            value=False, description="Discrete", disabled=False, indent=False
        )
        w_checkbox_none = widgets.Checkbox(
            value=False, description="None", disabled=False, indent=False
        )

        return {
            "w_x_min": w_x_min,
            "w_x_max": w_x_max,
            "w_ncols": w_ncols,
            "w_nrows": w_nrows,
            "w_extra": w_extra,
            "w_repr": w_repr,
            "w_distributions": w_distributions,
            "w_checkbox_cont": w_checkbox_cont,
            "w_checkbox_disc": w_checkbox_disc,
            "w_checkbox_none": w_checkbox_none,
        }

    def _setup_observers(self):
        """Connect the controls to the grid and to the fitting routine.

        Every handler listens to the ``value`` trait only. Observing all traits
        would run the handlers again for each companion trait that changes
        along with ``value``, repeating the fit for a single user action.
        """
        for key in ("w_checkbox_none", "w_checkbox_cont", "w_checkbox_disc"):
            self._widgets[key].observe(self._handle_checkbox_change, names="value")

        def _update_grid_(_):
            self._x_min = self._widgets["w_x_min"].value
            self._x_max = self._widgets["w_x_max"].value
            self._nrows = self._widgets["w_nrows"].value
            self._ncols = self._widgets["w_ncols"].value
            self._update_grid()

        for key in ("w_x_min", "w_x_max", "w_nrows", "w_ncols"):
            self._widgets[key].observe(_update_grid_, names="value")

        # Registered after the grid update so it runs second for ``x_min``.
        self._widgets["w_x_min"].observe(self._on_value_change, names="value")

        def _on_leave_fig_(_):
            self._on_leave_fig()

        for key in ("w_repr", "w_distributions", "w_extra"):
            self._widgets[key].observe(_on_leave_fig_, names="value")

    def _handle_checkbox_change(self, _):
        """Push the checkbox-driven selection into the distribution menu."""
        dist_names = self._handle_checkbox_widget()
        self._widgets["w_distributions"].value = dist_names

    def _on_value_change(self, change):
        """Keep ``x_max`` different from ``x_min`` by bumping it by one on a tie."""
        new_a = change["new"]
        if new_a == self._widgets["w_x_max"].value:
            self._widgets["w_x_max"].value = new_a + 1
class _Rectangles:
    """Click handler that records the chips placed in each grid column.

    Parameters
    ----------
    fig : figure whose canvas delivers the mouse events.
    coll : 2D array of rectangle patches indexed as ``[row, column]``.
    nrows, ncols : int
        Grid dimensions.
    ax : axes holding the grid; clicks elsewhere are ignored.

    The chip count of every column is kept in ``_weights``, a dict mapping the
    column index to the number of filled rows.
    """

    def __init__(self, fig, coll, nrows, ncols, ax):
        self._fig = fig
        self._coll = coll
        self._nrows = nrows
        self._ncols = ncols
        self._ax = ax
        self._weights = {k: 0 for k in range(0, ncols)}
        # The instance itself is the callback, see ``__call__``.
        fig.canvas.mpl_connect("button_press_event", self)

    def __call__(self, event):
        """Update the clicked column and repaint it.

        Clicking a row above the current level fills the column up to that row.
        Clicking at or below the current level lowers the column to one row
        under the click.
        """
        if event.inaxes != self._ax:
            return

        idx = floor(event.xdata)
        idy = ceil(event.ydata)
        if not (0 <= idx < self._ncols and 0 <= idy <= self._nrows):
            return

        if self._weights[idx] >= idy:
            # ``ceil`` gives 0 for clicks in (-1, 0], i.e. in the margin just
            # below the grid; clamp so the chip count never becomes negative.
            idy = max(idy - 1, 0)
            for row in range(self._nrows):
                self._coll[row, idx].set_facecolor("0.8")

        self._weights[idx] = idy
        for row in range(idy):
            self._coll[row, idx].set_facecolor("C1")
        self._fig.canvas.draw()