"""Methods to communicate with PyMC."""
import warnings
from copy import copy
from sys import modules
import numpy as np
try:
from pymc.pytensorf import compile, join_nonshared_inputs
from pymc.util import get_untransformed_name, is_transformed_name
from pytensor import function
from pytensor.graph.traversal import ancestors
from pytensor.tensor import TensorConstant, matrix
except ImportError:
warnings.warn("PyMC not installed. PyMC related functions will not work.")
from preliz.internal.distribution_helper import get_distributions
def back_fitting_pymc(prior, preliz_model, var_info, new_families=None):
    """
    Fit the samples from prior into the user provided model's priors.

    From the perspective of ppe "prior" is actually an approximated posterior,
    but from the user's perspective it is their prior. We need to "backfit"
    because arbitrary samples cannot be used as priors; we need probability
    distributions.

    Parameters
    ----------
    prior : dict
        Samples per variable name. For variables with ``size > 1`` the samples
        are indexed as ``prior[name][:, i]`` (one column per element).
    preliz_model : dict
        PreliZ distribution per variable name. For ``size > 1`` variables the
        stored distribution object itself is refitted (mutated in place).
    var_info : dict
        Per-variable tuples whose second item is the variable size.
    new_families : None, "auto", list or dict, optional
        Extra candidate families, only used for variables with ``size <= 1``
        (forwarded to ``set_families``).

    Returns
    -------
    dict
        Fitted PreliZ distribution per variable name.
    """
    fitted = {}
    for rv_name, (_, size, *_) in var_info.items():
        if size > 1:
            samples = prior[rv_name]
            dist = preliz_model[rv_name]
            # One MLE fit per element; alternative families are not explored here.
            per_element = []
            for col in range(size):
                dist._fit_mle(samples[:, col])
                per_element.append(dist.params)
            # Transpose: one array per parameter, one entry per element.
            dist._parametrization(*[np.array(column) for column in zip(*per_element)])
        else:
            samples = prior[rv_name]
            candidates = set_families(preliz_model[rv_name], rv_name, new_families)
            mle = getattr(modules["preliz.unidimensional"], "mle")
            idx, _ = mle(candidates, samples, plot=False)
            dist = candidates[idx[0]]
        fitted[rv_name] = dist
    return fitted
def set_families(dist, var, new_families):
    """
    Build the list of candidate distributions to try when back-fitting ``var``.

    Parameters
    ----------
    dist : PreliZ distribution
        Distribution currently assigned to ``var``; always the first candidate.
    var : str
        Variable name, used as lookup key when ``new_families`` is a dict.
    new_families : None, "auto", list, tuple or dict
        ``None`` keeps only ``dist``. ``"auto"`` adds Normal, HalfNormal and
        Gamma, skipping the one whose class name matches ``dist``. A list or
        tuple is appended as is. A dict maps variable names to lists of extra
        candidates.

    Returns
    -------
    list
        Candidate distributions, starting with ``dist``.
    """
    dists = [dist]
    if new_families is None:
        return dists

    if isinstance(new_families, str) and new_families == "auto":
        pz_dists = modules["preliz.distributions"]
        # Instantiate the alternatives so they are the same kind of object as
        # ``dist`` (an instance). Comparing ``__class__.__name__`` of a class
        # object yields its metaclass name, so the same-family filter below
        # could never match when the classes themselves were collected.
        alternatives = [getattr(pz_dists, name)() for name in ("Normal", "HalfNormal", "Gamma")]
        dists += [alt for alt in alternatives if alt.__class__.__name__ != dist.__class__.__name__]
    elif isinstance(new_families, (list, tuple)):
        dists += list(new_families)
    elif isinstance(new_families, dict):
        dists += new_families.get(var, [])
    return dists
def compile_mllk(model):
    """
    Compile the log-likelihood for a pymc model.

    The compiled function allows us to condition on both data and parameters.

    Parameters
    ----------
    model : PyMC model
        Only its first observed RV is used.

    Returns
    -------
    fmodel : callable
        ``fmodel(params, obs)`` returns the negated sum of ``model.datalogp``
        evaluated at the flattened parameter vector ``params`` and data ``obs``.
    old_y_value : the value variable previously mapped to the observed RV.
    obs_rvs : the observed RV whose value variable was replaced.

    Notes
    -----
    Side effect: ``model.rvs_to_values[obs_rvs]`` is overwritten with a fresh
    symbolic variable. The previous value is returned, presumably so the
    caller can restore it afterwards -- confirm at the call site.
    """
    # Only the first observed RV is considered.
    obs_rvs = model.observed_RVs[0]
    old_y_value = model.rvs_to_values[obs_rvs]
    # Fresh symbolic variable of the same type; it becomes an explicit input of
    # the compiled function below, so the observed data can be passed per call.
    new_y_value = obs_rvs.type()
    model.rvs_to_values[obs_rvs] = new_y_value
    vars_ = model.value_vars
    initial_point = model.initial_point()
    # Replace the individual value variables by a single raveled input vector
    # (``raveled_inp``); ``initial_point`` is passed as the reference point.
    [logp], raveled_inp = join_nonshared_inputs(
        point=initial_point, outputs=[model.datalogp], inputs=vars_
    )
    rv_logp_fn = compile([raveled_inp, new_y_value], logp)
    # Skip PyTensor's per-call input validation/conversion; callers must pass
    # inputs of the exact expected type.
    rv_logp_fn.trust_input = True

    def fmodel(params, obs):
        # Negative log-likelihood, suitable for minimization.
        return -rv_logp_fn(params, obs).sum()

    return fmodel, old_y_value, obs_rvs
def get_initial_guess(model):
    """Return the model's initial point flattened into a single 1-D array.

    Values are concatenated in the iteration order of ``model.initial_point()``.
    """
    point = model.initial_point()
    flat_parts = [np.asarray(value).reshape(-1) for value in point.values()]
    return np.concatenate(flat_parts)
def extract_preliz_distributions(model):
    """
    Extract the corresponding PreliZ distributions from a PyMC model.

    Parameters
    ----------
    model : a PyMC model

    Returns
    -------
    preliz_model : a dictionary of RVs names as keys and PreliZ distributions as values
        One (copied) PreliZ distribution per free RV of ``model``.

    Raises
    ------
    KeyError
        If a free RV's distribution name has no PreliZ counterpart.
    """
    # These PreliZ names are excluded from the name -> distribution lookup table.
    excluded = {"Truncated", "Censored", "Hurdle", "Mixture"}
    all_distributions = [
        dist for dist in modules["preliz.distributions"].__all__ if dist not in excluded
    ]
    # Keys are lower-cased so they can be matched against PyMC op names.
    pymc_to_preliz = dict(
        zip([dist.lower() for dist in all_distributions], get_distributions(all_distributions))
    )

    preliz_model = {}
    for r_v in model.free_RVs:
        op = r_v.owner.op
        # Fall back to the op's string form (text before "RV", lower-cased)
        # when the op has no name.
        dist_name = op.name if op.name else str(op).split("RV", 1)[0].lower()
        # Copy so each RV gets its own distribution object.
        preliz_model[r_v.name] = copy(pymc_to_preliz[dist_name])
    return preliz_model
def retrieve_variable_info(model):
    """
    Get shape, size, transformation and parents of each free RV in a PyMC model.

    Returns
    -------
    var_info : dict
        Keyed by the (untransformed) variable name. Each value is a tuple
        ``(shape, size, transformation, idx_parents)``:

        - ``shape`` and ``size`` come from the model's initial point.
        - ``transformation`` is a compiled backward transform for transformed
          value variables, otherwise ``None``.
        - ``idx_parents`` holds the positions in ``model.free_RVs`` of the
          non-constant parents.
    num_draws : int
        Number of elements of the first observed RV (``.eval().size``).
    """
    point = model.initial_point()
    info = {}
    for value_var in model.value_vars:
        raw_name = value_var.name
        rv = model.values_to_rvs[value_var]
        parents = non_constant_parents(rv, model)
        idx_parents = [model.free_RVs.index(parent) for parent in parents] if parents else []

        if not is_transformed_name(raw_name):
            key, transformation = raw_name, None
        else:
            key = get_untransformed_name(raw_name)
            # Compile the backward transform so values on the transformed
            # (unconstrained) scale can be mapped back.
            x_var = matrix(f"{key}_transformed")
            z_var = model.rvs_to_transforms[rv].backward(x_var)
            transformation = function(inputs=[x_var], outputs=z_var)

        value = point[raw_name]
        info[key] = (value.shape, value.size, transformation, idx_parents)

    num_draws = model.observed_RVs[0].eval().size
    return info, num_draws
def unravel_projection(prior_array, var_info, iterations):
    """
    Split a 2-D array of flattened samples back into per-variable arrays.

    Parameters
    ----------
    prior_array : array
        Indexed as ``prior_array[:, start:stop]``; columns are consumed in the
        iteration order of ``var_info``.
    var_info : dict
        Maps names to ``(shape, size, transformation, _)`` tuples. ``size``
        columns are taken per variable, passed through ``transformation`` when
        it is not None, then reshaped to ``(iterations, *shape)``.
    iterations : int
        Leading dimension used for the reshape.

    Returns
    -------
    dict
        Name -> array, with all length-1 axes squeezed out.
    """
    unraveled = {}
    offset = 0
    for key, (shape, width, transformation, _) in var_info.items():
        block = prior_array[:, offset : offset + width]
        if transformation is not None:
            block = transformation(block)
        unraveled[key] = block.reshape(iterations, *shape).squeeze()
        offset += width
    return unraveled
def _format_pymc_line(var_name, dist, size):
    """Render one ``pm.<Dist>("name", ...)`` assignment line from ``repr(dist)``."""
    dist_name, dist_params = repr(dist).split("(")
    if size > 1:
        # Drop the repr's closing parenthesis so ``shape=`` can be appended.
        dist_params = dist_params.split(")")[0]
        return f' {var_name} = pm.{dist_name}("{var_name}", {dist_params}, shape={size})\n'
    return f' {var_name} = pm.{dist_name}("{var_name}", {dist_params}\n'


def write_pymc_string(new_priors, var_info):
    """
    Return a string with the new priors for the PyMC model.

    The string is meant to be copied and pasted by the user, ideally with
    none to minimal changes.

    Variables without non-constant parents get their own line. For a variable
    with parents, each parent's line (at the parent's index) is rewritten using
    a copy of the parent's distribution refitted by moments to this variable's
    mean and standard deviation. No separate line is emitted for the variable
    itself.
    """
    header = "with pm.Model() as model:\n"
    lines = []
    ordered_names = list(new_priors)
    for key, value in new_priors.items():
        parent_idxs = var_info[key][-1]
        if not parent_idxs:
            lines.append(_format_pymc_line(key, value, var_info[key][1]))
            continue
        for i in parent_idxs:
            parent = ordered_names[i]
            refitted = copy(new_priors[parent])
            refitted._fit_moments(np.mean(value.mean()), np.mean(value.std()))
            # Assumes the parent's line was already appended at position i
            # (parents precede children in iteration order) -- confirm.
            lines[i] = _format_pymc_line(parent, refitted, var_info[parent][1])
    return "".join([header] + lines)
def non_constant_parents(rvs, model):
    """
    Find the parents of a variable that are not constant.

    Parameters
    ----------
    rvs : random variable of ``model``
        Its owner node's inputs from index 2 onward are inspected.
    model : PyMC model

    Returns
    -------
    list
        Free RVs of ``model`` that are ancestors of any non-constant input,
        without duplicates, in ``model.free_RVs`` discovery order.
    """
    parents = []
    for variable in rvs.get_parents()[0].inputs[2:]:
        if isinstance(variable, TensorConstant):
            continue
        # Traverse the graph once per input instead of once per
        # (input, free RV) pair as before.
        variable_ancestors = list(ancestors([variable]))
        for free_rv in model.free_RVs:
            if free_rv in variable_ancestors and free_rv not in parents:
                parents.append(free_rv)
    return parents
def if_pymc_get_preliz(dist):
    """Convert ``dist`` to PreliZ if it is a PyMC variable or a Prior; otherwise return it as is.

    The check is done on the class name only, so neither PyMC nor pymc-extras
    needs to be imported here.
    """
    class_name = dist.__class__.__name__
    if class_name == "TensorVariable":
        return from_pymc(dist)
    if class_name == "Prior":
        return from_prior(dist)
    return dist
[docs]
def from_pymc(dist):
    """Convert a PyMC distribution to a PreliZ distribution.

    Parameters
    ----------
    dist : PyMC distribution

    Returns
    -------
    PreliZ distribution

    Raises
    ------
    NotImplementedError
        If there is no PreliZ distribution with the PyMC distribution's name.
    """
    name = dist.owner.op._print_name[0]
    if name == "MultivariateNormal":
        name = "MvNormal"

    if name == "Censored":
        # inputs: [base distribution, lower, upper]; NaN bounds become +/- inf.
        base_dist = dist.owner.inputs[0]
        lower = _nan_bound_to(_as_scalar(dist.owner.inputs[1].eval()), -np.inf)
        upper = _nan_bound_to(_as_scalar(dist.owner.inputs[2].eval()), np.inf)
        BaseDist = from_pymc(base_dist)
        return modules["preliz.distributions"].Censored(BaseDist, lower=lower, upper=upper)

    if "Truncated" in name and name != "TruncatedNormal":
        # Truncated RV inputs end with (lower, upper); the base parameters sit
        # between the first two inputs and those bounds.
        base_dist_name = name.split("Truncated")[1]
        base_params = [v.eval() for v in dist.owner.inputs[2:-2]]
        lower = _nan_bound_to(_as_scalar(dist.owner.inputs[-2].eval()), -np.inf)
        upper = _nan_bound_to(_as_scalar(dist.owner.inputs[-1].eval()), np.inf)
        BaseDist = getattr(modules["preliz.distributions"], base_dist_name)
        return modules["preliz.distributions"].Truncated(
            _reparametrize(BaseDist, base_dist_name, base_params), lower=lower, upper=upper
        )

    if name == "Mixture":
        name_0 = dist.owner.inputs[2].owner.op._print_name[0]
        if name_0 == "DiracDelta":
            # Point mass + base distribution -> PreliZ ZeroInflated<Base>.
            base_node = dist.owner.inputs[-1]
            base_name = base_node.owner.op._print_name[0]
            base_params = [v.eval() for v in base_node.owner.inputs[2:]]
            # Weight of the second component.
            psi = _nan_to_none(dist.owner.inputs[1].eval())[1]
            ZeroInflated = getattr(modules["preliz.distributions"], f"ZeroInflated{base_name}")
            if base_name == "NegativeBinomial":
                # (n, p) -> (mu, n)
                n, p = base_params
                mu = n * (1 - p) / p
                base_params = [mu, n]
            base_params = _nan_to_none(base_params)
            return ZeroInflated(psi, *base_params)
        components = dist.owner.inputs[2:]
        weights = _nan_to_none(dist.owner.inputs[1].eval())
        PreliZ_components = [from_pymc(comp) for comp in components]
        return modules["preliz.distributions"].Mixture(PreliZ_components, weights=weights)

    if name == "Hurdle":
        # The last component's op name carries a "Truncated" prefix (stripped
        # below), i.e. it is the same kind of node handled by the Truncated
        # branch above. Its inputs therefore end with (lower, upper), which
        # must not be passed on as base parameters.
        base_node = dist.owner.inputs[-1]
        base_type_name = base_node.owner.op._print_name[0].replace("Truncated", "")
        psi = _nan_to_none(dist.owner.inputs[1].eval())[-1]
        base_params = [v.eval() for v in base_node.owner.inputs[2:-2]]
        BaseDist = getattr(modules["preliz.distributions"], base_type_name)
        return modules["preliz.distributions"].Hurdle(
            _reparametrize(BaseDist, base_type_name, base_params), psi
        )

    if name == "HalfNormal":
        # Skips one more leading input than the default case (presumably a
        # fixed location parameter -- TODO confirm).
        params_inputs = [v.eval() for v in dist.owner.inputs[3:]]
    elif name == "LogitNormal":
        params_inputs = [v.eval() for v in dist.owner.inputs[1:]]
    else:
        params_inputs = [v.eval() for v in dist.owner.inputs[2:]]
    # Keep only numeric evaluated inputs.
    params_inputs = [
        p for p in params_inputs if isinstance(p, (int, float, np.number, np.ndarray))
    ]
    try:
        Dist = getattr(modules["preliz.distributions"], name)
    except AttributeError:
        raise NotImplementedError(f"No PreliZ distribution named {name}") from None
    return _reparametrize(Dist, name, params_inputs)


def _nan_bound_to(bound, fill):
    """Replace NaN in a censoring/truncation bound by ``fill``.

    Scalars are checked directly. Arrays are handled element-wise (a plain
    ``if np.isnan(bound)`` raises for arrays with more than one element).
    """
    if np.ndim(bound) == 0:
        return fill if np.isnan(bound) else bound
    return np.where(np.isnan(bound), fill, bound)
[docs]
def from_prior(prior):
    """Convert a Prior distribution (from pymc-extras) to a PreliZ distribution.

    Parameters
    ----------
    prior : Prior distribution
        Object exposing ``distribution`` (the distribution name) and
        ``to_dict()``, whose optional ``"kwargs"`` entry holds the parameters.

    Returns
    -------
    PreliZ distribution

    Raises
    ------
    ValueError
        If ``prior.distribution`` has no PreliZ counterpart.
    """
    dist_name = prior.distribution
    kwargs = prior.to_dict().get("kwargs", {})
    pz_dists = modules["preliz.distributions"]

    if dist_name in ("Truncated", "Censored"):
        # The base distribution arrives serialized as {"dist": name, "kwargs": {...}}.
        dist_arg = kwargs["dist"]
        base = getattr(pz_dists, dist_arg["dist"])(**dist_arg.get("kwargs", {}))
        return getattr(pz_dists, dist_name)(
            base, lower=kwargs.get("lower"), upper=kwargs.get("upper")
        )

    if dist_name.startswith("Hurdle"):
        # e.g. "HurdlePoisson": every kwarg except psi belongs to the base distribution.
        base_name = dist_name.replace("Hurdle", "")
        base_kwargs = {k: v for k, v in kwargs.items() if k != "psi"}
        base = getattr(pz_dists, base_name)(**base_kwargs)
        return pz_dists.Hurdle(base, psi=kwargs["psi"])

    if dist_name == "Mixture":
        comp_dists = [from_prior(d) for d in kwargs["comp_dists"]]
        return pz_dists.Mixture(comp_dists, weights=kwargs["w"])

    if hasattr(pz_dists, dist_name):
        return getattr(pz_dists, dist_name)(**kwargs)
    raise ValueError(f"Unknown PreliZ distribution: {dist_name}")
def _as_scalar(x):
x = np.asarray(x)
return x.item() if x.shape == () or x.size == 1 else x
def _reparametrize(Dist, name, params_inputs):
    """
    Instantiate the PreliZ distribution ``Dist`` from PyMC-ordered parameters.

    For the distributions listed below, the PyMC parameter order or
    parametrization is translated to PreliZ keyword arguments (for example,
    the scale of Gamma/Exponential is converted to a rate). NaN values
    become ``None``. Any other ``name`` passes the cleaned parameters
    positionally.
    """

    def _reciprocal(value):
        return None if value is None else 1 / value

    if name == "Rice":
        # nu = b * sigma (computed before NaN cleaning, as the product needs raw values).
        b, sigma = params_inputs
        nu, sigma = _nan_to_none((b * sigma, sigma))
        return Dist(nu=nu, sigma=sigma)

    params = _nan_to_none(params_inputs)

    if name == "AsymmetricLaplace":
        b, kappa, mu = params
        return Dist(mu=mu, b=b, kappa=kappa)
    if name == "Exponential":
        # scale -> rate
        return Dist(_reciprocal(params[0]))
    if name == "Gamma":
        alpha, inv_beta = params
        return Dist(alpha=alpha, beta=_reciprocal(inv_beta))
    if name == "SkewStudentT":
        a, b, mu, sigma = params
        return Dist(mu=mu, sigma=sigma, a=a, b=b)
    if name == "Wald":
        mu, lam, _ = params
        return Dist(mu=mu, lam=lam)
    if name == "BetaBinomial":
        n, alpha, beta = params
        return Dist(alpha=alpha, beta=beta, n=n)
    if name == "NegativeBinomial":
        # (n, p) -> (mu, alpha)
        n, p = params
        mu = n * (1 - p) / p if (p is not None and n is not None) else None
        return Dist(mu=mu, alpha=n)
    if name == "HyperGeometric":
        # (good, bad, n) -> (N = good + bad, k = good, n)
        good, bad, n = params
        N = good + bad if (good is not None and bad is not None) else None
        return Dist(N=N, k=good, n=n)
    return Dist(*params)
def _nan_to_none(params):
if np.isscalar(params):
return None if np.isnan(params) else params
result = []
for p in params:
arr = np.asarray(p)
if arr.size == 1:
val = arr.item()
result.append(None if np.isnan(val) else val)
else:
mask = np.isnan(arr)
if np.any(mask):
out = arr.astype(object)
out[mask] = None
result.append(out)
else:
result.append(p)
return result