# pylint: disable = protected-access
# pylint: disable = too-many-function-args
# pylint: disable = too-many-nested-blocks
from dataclasses import dataclass, field
import itertools
from typing import Dict, Union
import arviz as az
import numpy as np
import pandas as pd
import xarray as xr
from bambi.models import Model
from bambi.plots.create_data import create_cap_data, create_comparisons_data
from bambi.plots.utils import average_over, ConditionalInfo, ContrastInfo, enforce_dtypes, identity
from bambi.utils import get_aliased_name, listify
@dataclass
class ResponseInfo:
name: str
target: Union[str, None] = None
lower_bound: float = 0.03
upper_bound: float = 0.97
name_target: str = field(init=False)
name_obs: str = field(init=False)
lower_bound_name: str = field(init=False)
upper_bound_name: str = field(init=False)
def __post_init__(self):
"""
Assigns commonly used f-strings for indexing and column names as attributes.
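
        For example, ``ResponseInfo("y", target="mean")`` yields ``name_target="y_mean"``
        and ``name_obs="y_obs"``; with the default bounds, ``lower_bound_name`` is
        ``"lower_3.0%"`` and ``upper_bound_name`` is ``"upper_97.0%"``.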
"""
if self.target is None:
self.name_target = self.name
else:
self.name_target = f"{self.name}_{self.target}"
self.name_obs = f"{self.name}_obs"
self.lower_bound_name = f"lower_{self.lower_bound * 100}%"
self.upper_bound_name = f"upper_{self.upper_bound * 100}%"
def predictions(
model: Model,
idata: az.InferenceData,
covariates: Union[str, dict, list],
target: str = "mean",
pps: bool = False,
use_hdi: bool = True,
    prob: Union[float, None] = None,
    transforms: Union[dict, None] = None,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Predictions
Parameters
----------
model : bambi.Model
        The model for which we want to compute the predictions.
idata : arviz.InferenceData
The InferenceData object that contains the samples from the posterior distribution of
the model.
    covariates : str, list or dict
A sequence of between one and three names of variables or a dict of length between one
and three.
If a sequence, the first variable is taken as the main variable and is mapped to the
horizontal axis. If present, the second name is a coloring/grouping variable,
and the third is mapped to different plot panels.
If a dictionary, keys must be taken from ("main", "group", "panel") and the values
are the names of the variables.
    target : str
        Which model parameter to compute predictions for. Defaults to ``"mean"``. Passing a
        parameter into ``target`` only works when ``pps`` is ``False``, as the target may not
        be available in the posterior predictive distribution.
    pps : bool, optional
        Whether to compute the predictions from the posterior predictive samples.
        Defaults to ``False``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
    prob : float, optional
        The probability for the credible intervals. Must be between 0 and 1. Defaults to 0.94.
        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
Returns
-------
    cap_data : pandas.DataFrame
        A DataFrame with the data generated by ``create_cap_data`` and the corresponding
        model predictions.
Raises
------
ValueError
If ``pps`` is ``True`` and ``target`` is not ``"mean"``.
        If the passed ``covariates`` is not in the correct key-value format.
If length of ``covariates`` is not between 1 and 3.
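
    Examples
    --------
    A minimal sketch, assuming a hypothetical dataset with a numeric predictor ``x``,
    a categorical predictor ``group``, and a Gaussian response ``y`` (the variable
    names and import path are illustrative assumptions)::

        import numpy as np
        import pandas as pd
        import bambi as bmb
        from bambi.plots import predictions

        rng = np.random.default_rng(0)
        data = pd.DataFrame(
            {"x": rng.normal(size=100), "group": rng.choice(["a", "b"], size=100)}
        )
        data["y"] = 2 * data["x"] + rng.normal(size=100)

        model = bmb.Model("y ~ x + group", data)
        idata = model.fit()

        # one row per point of the evaluation grid, with "estimate" and
        # lower/upper interval columns appended
        summary_df = predictions(model, idata, covariates=["x", "group"])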
"""
if pps and target != "mean":
raise ValueError("When passing 'pps=True', target must be 'mean'")
    covariate_kinds = ("main", "group", "panel")
    if not isinstance(covariates, dict):
        covariates = listify(covariates)
        covariates = dict(zip(covariate_kinds, covariates))
    else:
        if covariate_kinds[0] not in covariates:
            raise ValueError(f"The first key of 'covariates' must be '{covariate_kinds[0]}'.")
        if not set(covariates).issubset(set(covariate_kinds)):
            raise ValueError(f"The keys of 'covariates' must be a subset of {covariate_kinds}.")

    if not 1 <= len(covariates) <= 3:
        raise ValueError(
            f"The number of covariates must be between 1 and 3. It is {len(covariates)}."
        )
if transforms is None:
transforms = {}
    if prob is None:
        prob = az.rcParams["stats.hdi_prob"]
    if not 0 < prob < 1:
        raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")

    lower_bound = round((1 - prob) / 2, 4)
    upper_bound = 1 - lower_bound
cap_data = create_cap_data(model, covariates)
if target != "mean":
component = model.components[target]
if component.alias:
# use only the aliased name (without appended target)
response_name = get_aliased_name(component)
target = None
else:
# use the default response "y" and append target
response_name = get_aliased_name(model.response_component.response_term)
else:
response_name = get_aliased_name(model.response_component.response_term)
    # pass the interval bounds at construction so the generated column names match 'prob'
    response = ResponseInfo(response_name, target, lower_bound, upper_bound)
response_transform = transforms.get(response_name, identity)
    if pps:
        idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
        y_hat = response_transform(idata.posterior_predictive[response.name])
    else:
        idata = model.predict(idata, data=cap_data, inplace=False)
        y_hat = response_transform(idata.posterior[response.name_target])
    y_hat_mean = y_hat.mean(("chain", "draw"))
    if use_hdi and pps:
        y_hat_bounds = az.hdi(y_hat, prob)[response.name].T
    elif use_hdi:
        y_hat_bounds = az.hdi(y_hat, prob)[response.name_target].T
    else:
        y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))
cap_data["estimate"] = y_hat_mean
cap_data[response.lower_bound_name] = y_hat_bounds[0]
cap_data[response.upper_bound_name] = y_hat_bounds[1]
return cap_data
@dataclass
class ContrastEstimate:
comparison: Dict[str, xr.DataArray]
hdi: Dict[str, xr.Dataset]
def comparisons(
model: Model,
idata: az.InferenceData,
contrast: Union[str, dict, list],
conditional: Union[str, dict, list, None] = None,
average_by: Union[str, list, bool, None] = None,
comparison_type: str = "diff",
use_hdi: bool = True,
    prob: Union[float, None] = None,
    transforms: Union[dict, None] = None,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Comparisons
Parameters
----------
model : bambi.Model
        The model for which we want to compute the comparisons.
idata : arviz.InferenceData
The InferenceData object that contains the samples from the posterior distribution of
the model.
    contrast : str, dict, list
        The name of the predictor whose levels we would like to compare.
conditional : str, dict, list
The covariates we would like to condition on.
    average_by : str, list, bool, optional
        The covariates we would like to average by. The estimates are marginalized over the
        remaining covariates in the model. If ``True``, all covariates in the model are
        averaged over to obtain a single overall estimate. Defaults to ``None``.
    comparison_type : str, optional
        The type of comparison to compute. Must be either ``"diff"`` (difference) or
        ``"ratio"``. Defaults to ``"diff"``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
    prob : float, optional
        The probability for the credible intervals. Must be between 0 and 1. Defaults to 0.94.
        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
Returns
-------
pandas.DataFrame
A dataframe with the comparison values, highest density interval, contrast name,
contrast value, and conditional values.
Raises
------
ValueError
        If more than one contrast predictor is passed.
If ``contrast`` is not a string, dictionary, or list.
If ``comparison_type`` is not 'diff' or 'ratio'.
If ``prob`` is not > 0 and < 1.
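
    Examples
    --------
    A minimal sketch, assuming a hypothetical fitted model with a binary predictor
    ``treat`` and a numeric covariate ``age`` (names and import path are illustrative
    assumptions)::

        from bambi.plots import comparisons

        # difference in predictions between treat=1 and treat=0,
        # conditional on a grid of 'age' values
        summary_df = comparisons(
            model, idata, contrast={"treat": [0, 1]}, conditional="age"
        )

        # marginalize over 'age' to obtain a single overall contrast estimate
        summary_avg = comparisons(
            model, idata, contrast={"treat": [0, 1]}, conditional="age", average_by=True
        )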
"""
if not isinstance(contrast, (dict, list, str)):
raise ValueError("'contrast' must be a string, dictionary, or list.")
if isinstance(contrast, (dict, list)):
if len(contrast) > 1:
raise ValueError(
f"Only one contrast predictor can be passed. {len(contrast)} were passed."
)
if comparison_type not in ("diff", "ratio"):
raise ValueError("'comparison_type' must be 'diff' or 'ratio'")
if prob is None:
prob = az.rcParams["stats.hdi_prob"]
if not 0 < prob < 1:
raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")
comparison_functions = {"diff": lambda x, y: x - y, "ratio": lambda x, y: x / y}
lower_bound = round((1 - prob) / 2, 4)
upper_bound = 1 - lower_bound
contrast_info = ContrastInfo(model, contrast)
conditional_info = ConditionalInfo(model, conditional)
# 'comparisons' should not be restricted to ("main", "group", "panel")
comparisons_df = create_comparisons_data(
conditional_info, contrast_info, user_passed=conditional_info.user_passed
)
if transforms is None:
transforms = {}
response_name = get_aliased_name(model.response_component.response_term)
response = ResponseInfo(
response_name, target="mean", lower_bound=lower_bound, upper_bound=upper_bound
)
# perform predictions on new data
idata = model.predict(idata, data=comparisons_df, inplace=False)
def _compute_contrast_estimate(
contrast: ContrastInfo,
response: ResponseInfo,
comparisons_df: pd.DataFrame,
idata: az.InferenceData,
) -> ContrastEstimate:
"""
Computes the contrast comparison estimate and highest density interval
for a given contrast and response by first subsetting posterior draws
using a contrast mask. Then, pairwise comparisons are computed for the
contrast values. Finally, the mean comparison and lower/upper bounds
are computed for each pairwise comparison.
"""
        function = comparison_functions[comparison_type]

        draws = {}
        for val in contrast.values:
            mask = np.array(comparisons_df[contrast.name] == val)
            select_draw = idata.posterior[response.name_target].sel({response.name_obs: mask})
            select_draw = select_draw.assign_coords(
                {response.name_obs: np.arange(len(select_draw.coords[response.name_obs]))}
            )
            draws[val] = select_draw
        pairwise_contrasts = list(itertools.combinations(contrast.values, 2))

        comparison_mean = {}
        comparison_bounds = {}
        for pair in pairwise_contrasts:
            comparison_estimate = function(draws[pair[1]], draws[pair[0]])
            comparison_mean[pair] = comparison_estimate.mean(("chain", "draw"))
            if use_hdi:
                comparison_bounds[pair] = az.hdi(comparison_estimate, prob)
            else:
                # note: in this branch the 'hdi' field of ContrastEstimate
                # actually stores quantile bounds
                comparison_bounds[pair] = comparison_estimate.quantile(
                    q=(response.lower_bound, response.upper_bound), dim=("chain", "draw")
                )

        return ContrastEstimate(comparison_mean, comparison_bounds)
def _build_contrasts_df(
contrast: ContrastInfo,
condition: ConditionalInfo,
response: ResponseInfo,
comparisons_df: pd.DataFrame,
idata: az.InferenceData,
average_by,
) -> pd.DataFrame:
"""
Builds a dataframe with the comparison values and lower / upper bounds from
``_compute_contrast_estimate`` along with the contrast name, contrast value,
and conditional values.
"""
contrast_estimate = _compute_contrast_estimate(contrast, response, comparisons_df, idata)
# if two contrast values, then can drop duplicates to build contrast_df
        if len(contrast.values) < 3:
            if not any(condition.covariates.values()):
                contrast_df = model.data[comparisons_df.columns].drop(columns=contrast.name)
            else:
                contrast_df = comparisons_df.drop_duplicates(
                    list(condition.covariates.values())
                ).reset_index(drop=True)
                contrast_df = contrast_df.drop(columns=contrast.name)

            num_rows = contrast_df.shape[0]
            contrast_df.insert(0, "term", contrast.name)
            contrast_df.insert(
                1, "contrast", list(np.tile(contrast.values, num_rows).reshape(num_rows, 2))
            )
            contrast_df["estimate"] = contrast_estimate.comparison[
                tuple(contrast.values)
            ].to_numpy()
if use_hdi:
contrast_df[response.lower_bound_name] = (
contrast_estimate.hdi[tuple(contrast.values)][response.name_target]
.sel(hdi="lower")
.values
)
contrast_df[response.upper_bound_name] = (
contrast_estimate.hdi[tuple(contrast.values)][response.name_target]
.sel(hdi="higher")
.values
)
else:
contrast_df[response.lower_bound_name] = contrast_estimate.hdi[
tuple(contrast.values)
].sel(quantile=lower_bound)
contrast_df[response.upper_bound_name] = contrast_estimate.hdi[
tuple(contrast.values)
].sel(quantile=upper_bound)
# if > 2 contrast values, then need the full dataframe to build contrast_df
elif len(contrast.values) >= 3:
contrast_keys = [list(elem) for elem in list(contrast_estimate.comparison.keys())]
covariate_cols = comparisons_df.drop(columns=contrast.name).columns
covariate_vals = (
comparisons_df.drop(columns=contrast.name).drop_duplicates().reset_index(drop=True)
).values
covariate_vals = np.tile(np.transpose(covariate_vals), len(contrast.values))
contrast_df = (
pd.DataFrame(contrast_estimate.comparison)
.unstack()
.reset_index()
.rename(columns={0: "estimate"})
)
# this hardcoded subset will not work for cross-contrasts
contrast_df.insert(0, "term", contrast.name)
contrast_df.insert(
1, "contrast", tuple(zip(contrast_df["level_0"], contrast_df["level_1"]))
)
contrast_df = contrast_df.drop(["level_0", "level_1", "level_2"], axis=1)
lower = []
upper = []
for pair in contrast_keys:
if use_hdi:
lower.append(
(
contrast_estimate.hdi[tuple(pair)][response.name_target]
.sel(hdi="lower")
.values
)
)
upper.append(
(
contrast_estimate.hdi[tuple(pair)][response.name_target]
.sel(hdi="higher")
.values
)
)
else:
lower.append(contrast_estimate.hdi[tuple(pair)].sel(quantile=lower_bound))
upper.append(contrast_estimate.hdi[tuple(pair)].sel(quantile=upper_bound))
contrast_df[covariate_cols] = np.transpose(covariate_vals)
contrast_df[response.lower_bound_name] = np.array(lower).flatten()
contrast_df[response.upper_bound_name] = np.array(upper).flatten()
contrast_df.insert(
len(contrast_df.columns) - 3, "estimate", contrast_df.pop("estimate")
)
contrast_df = enforce_dtypes(model.data, contrast_df)
contrast_df["contrast"] = contrast_df["contrast"].apply(tuple)
        if average_by:
            # 'average_by=True' marginalizes over every covariate in the model
            covariates_to_average = None if average_by is True else average_by
            contrast_df_avg = average_over(contrast_df, covariates_to_average)
            contrast_df_avg.insert(0, "term", contrast.name)
            contrast_df_avg.insert(
                1,
                "contrast",
                np.tile(contrast_df["contrast"].drop_duplicates(), len(contrast_df_avg)),
            )
            return contrast_df_avg.reset_index(drop=True)
        else:
            return contrast_df.reset_index(drop=True)
return _build_contrasts_df(
contrast_info,
conditional_info,
response,
comparisons_df,
idata,
average_by,
)