Source code for preliz.predictive.ppa

"""Prior predictive check assistant."""

import ast
import warnings
from random import shuffle

try:
    import ipywidgets as widgets
except ImportError:
    pass
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

from preliz.distributions import Normal
from preliz.distributions.distributions import Distribution
from preliz.internal.plot_helper import (
    check_inside_notebook,
    plot_pp_mean,
    plot_pp_samples,
)
from preliz.internal.predictive_helper import back_fitting_ppa, select_prior_samples
from preliz.ppls.agnostic import from_preliz, get_prior_pp_samples
from preliz.ppls.bambi_io import from_bambi


def ppa(
    fmodel,
    draws=2000,
    references=0,
    boundaries=(-np.inf, np.inf),
    target=None,
    new_families=True,
    engine="preliz",
):
    """
    Prior predictive check assistant.

    This is an experimental method under development, use with caution.

    Builds an interactive widget (it must run inside a notebook) that shows
    prior predictive samples, lets the user click on the ones that agree with
    their prior knowledge, and back-fits priors to the selected samples.
    Nothing is returned; the widgets are displayed as a side effect.

    Parameters
    ----------
    fmodel : PreliZ or Bambi model
        Model definition, interpreted according to ``engine``.
    draws : int
        Number of draws from the prior and prior predictive distribution
    references : int, float, list, tuple or dictionary
        Value(s) used as reference points representing prior knowledge. For example expected
        values or values that are considered extreme. Use a dictionary for labeled references.
    boundaries : tuple
        Hard boundaries (lower, upper). Prior predictive samples with values outside these
        boundaries will be excluded from the analysis.
    target : tuple or PreliZ distribution
        Target distribution. The first shown distributions will be selected to be as close
        as possible to `target`. Available options are, a PreliZ distribution or a 2-tuple with
        the first element representing the mean and the second the standard deviation.
    new_families : bool
        If True, the method will return the best fitting distribution from a set of common
        distributions
    engine : str
        Library used to define the model. Either `preliz` or `bambi`. Defaults to `preliz`
    """
    check_inside_notebook(need_widget=True)

    warnings.warn(
        "This is an experimental method under development, use with caution.",
        stacklevel=2,
    )

    filter_dists = FilterDistribution(
        fmodel, draws, references, boundaries, target, new_families, engine
    )
    # Sample the model and draw the initial figures.
    filter_dists()

    output = widgets.Output()

    with output:
        references_widget = widgets.Text(
            value=str(references),
            placeholder="Int, Float, tuple or dict",
            description="references: ",
            disabled=False,
            layout=widgets.Layout(width="230px", margin="0 20px 0 0"),
        )

        button_carry_on = widgets.Button(description="carry on")
        button_return_prior = widgets.Button(description="return prior")
        radio_buttons_kind = widgets.RadioButtons(
            options=["pdf", "hist", "ecdf"],
            value="pdf",
            description=" ",
            disabled=False,
        )

        check_button_sharex = widgets.Checkbox(
            value=True, description="sharex", disabled=False, indent=False
        )

        button_carry_on.on_click(
            lambda event: filter_dists.carry_on(radio_buttons_kind.value, check_button_sharex.value)
        )

        button_return_prior.on_click(lambda event: filter_dists.on_return_prior())

        def kind_(_):
            """Redraw both figures after any of the display controls changes."""
            kind = radio_buttons_kind.value

            # The text box holds a Python literal; anything unparsable means
            # "no references" rather than an error in the middle of typing.
            try:
                filter_dists.references = ast.literal_eval(references_widget.value)
            except (ValueError, SyntaxError):
                filter_dists.references = None

            plot_pp_samples(
                filter_dists.pp_samples,
                filter_dists.display_pp_idxs,
                filter_dists.references,
                kind,
                check_button_sharex.value,
                filter_dists.fig,
            )

            plot_pp_mean(
                filter_dists.pp_samples,
                list(filter_dists.selected),
                filter_dists.references,
                kind,
                filter_dists.fig_pp_mean,
            )

        references_widget.observe(kind_, names=["value"])
        radio_buttons_kind.observe(kind_, names=["value"])
        check_button_sharex.observe(kind_, names=["value"])

        def click(event):
            """Toggle the selection state of the axes that was clicked."""
            if event.inaxes is None:
                return

            if event.inaxes not in filter_dists.clicked:
                filter_dists.clicked.append(event.inaxes)
            else:
                filter_dists.clicked.remove(event.inaxes)
                # De-selected: restore the default thin black frame.
                plt.setp(event.inaxes.spines.values(), color="k", lw=1)

            # Highlight every currently selected axes.
            for ax in filter_dists.clicked:
                plt.setp(ax.spines.values(), color="C1", lw=3)
            filter_dists.fig.canvas.draw()

        filter_dists.fig.canvas.mpl_connect("button_press_event", click)

    controls = widgets.VBox([button_carry_on, button_return_prior])
    plot_combine = widgets.VBox([radio_buttons_kind, check_button_sharex])

    display(widgets.HBox([references_widget, plot_combine, controls, output]))  # noqa:F821
class FilterDistribution:
    """
    State and logic behind the prior predictive check assistant.

    An instance holds the prior and prior predictive samples of a model, the
    bookkeeping of which samples were shown / clicked / selected, and the
    matplotlib figures used to display them. Calling the instance performs the
    sampling and draws the initial figures.
    """

    # Probabilities at which samples are summarised before building the KDTree.
    _OCTILE_PROBS = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875)

    def __init__(self, fmodel, draws, references, boundaries, target, new_families, engine):
        self.fmodel = fmodel
        self.source = ""  # string representation of the model
        self.draws = draws
        self.references = references
        self.boundaries = boundaries
        self.target = target
        self.new_families = new_families
        self.engine = engine
        self.pp_samples = None  # prior predictive samples
        self.prior_samples = None  # prior samples used for backfitting
        self.display_pp_idxs = None  # indices of the pp_samples to be displayed
        self.pp_octiles = None  # octiles computed from pp_samples
        self.ref_octiles = None  # octiles computed from the target distribution
        self.kdt = None  # KDTree used to find similar distributions
        self.model = None  # parsed model used for backfitting
        self.clicked = []  # axes clicked by the user
        self.choices = []  # indices of the pp_samples selected by the user and not yet used to
        # find new distributions, this list can increase or decrease in size
        self.selected = set()  # indices of all the pp_samples selected by the user + machine
        # this set can only increase in size
        self.distances = {}  # distances between a selected distribution and its nearest neighbor
        self.shown = set()  # keep track to avoid showing the same distribution twice.
        self.fig = None  # figure used to display the pp_samples
        self.fig_pp_mean = None  # figure used to display the mean of the pp_samples
        self.axes = None  # axes used to display the pp_samples

    def __call__(self):
        """
        Sample the model and draw the initial figures.

        Raises
        ------
        ValueError
            If `engine` is neither ``"preliz"`` nor ``"bambi"``.
        """
        if self.engine == "preliz":
            variables, self.model = from_preliz(self.fmodel)
        elif self.engine == "bambi":
            self.fmodel, variables, self.model = from_bambi(self.fmodel, self.draws)
        else:
            # Fail with a clear message instead of a NameError on `variables`.
            raise ValueError(
                f"Unknown engine {self.engine!r}; expected 'preliz' or 'bambi'."
            )

        self.pp_samples, self.prior_samples = get_prior_pp_samples(
            self.fmodel, variables, self.draws, self.engine
        )

        if self.target is not None:
            self.add_target_dist()

        self.pp_octiles, self.kdt = self.compute_octiles()
        self.display_pp_idxs = self.initialize_subsamples(self.target)
        self.fig, self.axes = plot_pp_samples(
            self.pp_samples,
            self.display_pp_idxs,
            self.references,
        )
        self.fig_pp_mean = plot_pp_mean(self.pp_samples, self.selected, self.references)

    def add_target_dist(self):
        """
        Compute `ref_octiles` from a random sample of the target distribution.

        A tuple target is interpreted as the arguments of a `Normal`; a PreliZ
        `Distribution` is sampled directly. The sample has as many values as
        each prior predictive sample.

        Raises
        ------
        TypeError
            If `target` is neither a tuple nor a PreliZ distribution.
        """
        size = self.pp_samples.shape[1]
        if isinstance(self.target, tuple):
            ref_sample = Normal(*self.target).rvs(size)
        elif isinstance(self.target, Distribution):
            ref_sample = self.target.rvs(size)
        else:
            raise TypeError(
                "target must be a tuple or a PreliZ distribution, "
                f"got {type(self.target).__name__}"
            )

        self.ref_octiles = np.quantile(ref_sample, self._OCTILE_PROBS)

    def compute_octiles(self):
        """
        Compute the octiles for the prior predictive samples.

        This is used to find similar distributions using a KDTree.
        We have empirically found that octiles are a good choice, but this
        could be the consequence of limited testing.

        Returns
        -------
        pp_octiles : array with one row of 7 octiles per prior predictive sample
        kdt : KDTree built on `pp_octiles`
        """
        pp_octiles = np.quantile(self.pp_samples, self._OCTILE_PROBS, axis=1).T
        kdt = KDTree(pp_octiles)
        return pp_octiles, kdt

    def initialize_subsamples(self, target):
        """
        Initialize the subsamples to be displayed.

        If `target` is None, we search for the farthest_neighbor (this increases diversity)
        otherwise we search for the nearest_neighbor of target.

        The initialization takes into account the boundaries provided by the user:
        samples with values outside them are marked as shown and never returned.
        Updates the `shown` set (already shown pp_samples) and, when `target` is
        not None, also records the distance to the target in `distances`.

        Returns
        -------
        list of indices into `pp_samples` (possibly empty).
        """
        pp_idxs_to_display = []

        lower, upper = self.boundaries[0], self.boundaries[1]
        out_of_bounds = [
            idx
            for idx, sample in enumerate(self.pp_samples)
            if np.min(sample) < lower or np.max(sample) > upper
        ]
        # Treat out-of-bounds samples as already shown so they are never picked.
        self.shown.update(out_of_bounds)

        rng = np.random.default_rng(235)  # fixed seed, keeps the first pick reproducible

        # If we have not seen all the samples yet, we collect more
        if len(self.shown) != self.draws:
            if target is None:
                new = rng.choice(list(set(range(0, len(self.pp_octiles))) - self.shown))
                pp_idxs_to_display.append(new)

                for _ in range(8):
                    farthest_neighbor = self.draws
                    # Walk from the farthest neighbor inwards until an unseen
                    # sample is found; stop at k == 0, which is not a valid query.
                    while (
                        new in pp_idxs_to_display or new in self.shown
                    ) and farthest_neighbor >= 1:
                        _, new = self.kdt.query(
                            self.pp_octiles[pp_idxs_to_display[-1]],
                            [farthest_neighbor],
                            workers=-1,
                        )
                        new = new.item()
                        farthest_neighbor -= 1

                    if new in pp_idxs_to_display or new in self.shown:
                        # Every remaining sample has already been shown.
                        break

                    # Missing neighbors are indicated with index==sample_size
                    if new != self.draws:
                        pp_idxs_to_display.append(new)
            else:
                # -1 is a placeholder that forces the first search; dropped below.
                new = -1
                pp_idxs_to_display.append(new)

                for _ in range(9):
                    nearest_neighbor = 2
                    # Also skip `shown` so that out-of-bounds samples are never
                    # displayed, as in the other search loops of this class.
                    while new in pp_idxs_to_display or new in self.shown:
                        distance, new = self.kdt.query(
                            self.ref_octiles, [nearest_neighbor], workers=-1
                        )
                        new = new.item()
                        nearest_neighbor += 1

                    # Missing neighbors are indicated with index==sample_size
                    if new != self.draws:
                        pp_idxs_to_display.append(new)
                        self.distances[new] = distance.item()

                pp_idxs_to_display = pp_idxs_to_display[1:]

            self.shown.update(pp_idxs_to_display)

        return pp_idxs_to_display

    def keep_sampling(self):
        """
        Find distribution similar to the ones in `choices`, but not already shown.

        The first element of `choices` is consumed and used as the seed of a
        chain of nearest-neighbor searches. The distance of every new sample is
        stored in `distances` and the new samples are added to `shown`.
        If `choices` is empty return an empty selection.
        """
        if not self.choices:
            return []

        new = self.choices.pop(0)
        samples = [new]

        for _ in range(9):
            # k=1 is the query point itself, so start from the second neighbor.
            nearest_neighbor = 2
            while new in samples or new in self.shown:
                distance, new = self.kdt.query(
                    self.pp_octiles[samples[-1]], [nearest_neighbor], workers=-1
                )
                new = new.item()
                nearest_neighbor += 1

            # Missing neighbors are indicated with index==self.draws
            if new != self.draws:
                self.distances[new] = distance.item()
                samples.append(new)

        self.shown.update(samples[1:])

        return samples[1:]

    def collect_more_samples(self):
        """
        Automatically extend the set of user selected distributions.

        When more than two distances are recorded for the selected distributions
        we extend the selection with every distribution closer than `max_dist`
        to a selected one. `max_dist` is the trimmed mean of those distances,
        computed by discarding the lower and upper 10%. Otherwise only the
        nearest neighbor of each selected distribution is added.
        Candidates outside finite `boundaries` are discarded.
        """
        selected_distances = np.array([v for k, v in self.distances.items() if k in self.selected])

        if len(selected_distances) > 2:
            q_r = np.quantile(selected_distances, [0.1, 0.9])
            trimmed = selected_distances[
                (selected_distances > q_r[0]) & (selected_distances < q_r[1])
            ]
            if trimmed.size == 0:
                # Heavily tied distances leave nothing between the quantiles;
                # use the plain mean instead of propagating NaN.
                trimmed = selected_distances
            max_dist = np.mean(trimmed)
            upper = self.draws
        else:
            max_dist = np.inf
            upper = 3

        _, new = self.kdt.query(
            self.pp_octiles[list(self.selected)],
            range(2, upper),
            distance_upper_bound=max_dist,
            workers=-1,
        )
        # Missing neighbors are reported with an out-of-range index; drop them.
        new = new[new < self.draws].tolist()

        if np.any(np.isfinite(self.boundaries)):
            new = [
                n_s
                for n_s in new
                if np.min(self.pp_samples[n_s]) > self.boundaries[0]
                and np.max(self.pp_samples[n_s]) < self.boundaries[1]
            ]

        if new:
            self.selected.update(new)
            self.shown.update(new)

    def carry_on(self, kind, sharex):
        """Collect user input and update the plot."""
        self.fig.suptitle("")

        if self.clicked:
            # Axes titles hold the index of the pp_sample they display.
            self.choices.extend([int(ax.get_title()) for ax in self.clicked])
            shuffle(self.choices)
            self.selected.update(self.choices)
            self.collect_more_samples()

        # Reset the highlight of the clicked axes and wipe all panels.
        for ax in self.clicked:
            plt.setp(ax.spines.values(), color="k", lw=1)
        for ax in self.axes:
            ax.cla()
        self.clicked.clear()

        self.display_pp_idxs = self.keep_sampling()
        # if there is nothing to show initialize a new set of samples
        if not self.display_pp_idxs:
            self.display_pp_idxs = self.initialize_subsamples(None)

        plot_pp_mean(self.pp_samples, list(self.selected), self.references, kind, self.fig_pp_mean)

        if self.display_pp_idxs:
            plot_pp_samples(
                self.pp_samples, self.display_pp_idxs, self.references, kind, sharex, self.fig
            )
        else:
            # Instead of showing this message, we should resample.
            self.fig.clf()
            self.fig.suptitle("We have seen all the samples", fontsize=16)

        self.fig.canvas.draw()

    def on_return_prior(self):
        """
        Back-fit priors to the selected samples and show the result.

        Requires more than four selected distributions; otherwise a message
        asking for more selections is shown in the figure title.
        """
        selected = list(self.selected)

        if len(selected) > 4:
            subsample = select_prior_samples(selected, self.prior_samples, self.model)
            string, _ = back_fitting_ppa(self.model, subsample, new_families=self.new_families)

            self.fig.clf()
            # Draw on an axes owned by `self.fig`; pyplot's implicit "current
            # figure" may be a different one (e.g. the mean figure).
            ax = self.fig.add_subplot(111)
            ax.text(0.05, 0.5, string, fontsize=14)
            ax.set_yticks([])
            ax.set_xticks([])
        else:
            self.fig.suptitle("Please select more distributions", fontsize=16)

        self.fig.canvas.draw()