# pylint: disable = protected-access
# pylint: disable = too-many-function-args
# pylint: disable = too-many-nested-blocks
from statistics import mode
import arviz as az
import numpy as np
import pandas as pd
from arviz.plots.backends.matplotlib import create_axes_grid
from arviz.plots.plot_utils import default_grid
from formulae.terms.call import Call
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
from bambi.utils import listify, get_aliased_name
from bambi.plots.utils import get_group_offset, get_unique_levels
def create_cap_data(model, covariates, grid_n=200, groups_n=5):
    """Create data for a Conditional Adjusted Predictions plot

    The main covariate is crossed with the (optional) color and panel covariates, with the
    main covariate varying fastest and the panel covariate slowest. Every other variable
    found in the model terms is held fixed at a default value.

    Parameters
    ----------
    model : bambi.Model
        An instance of a Bambi model
    covariates : dict
        A dictionary of length between one and three.
        Keys must be taken from ("horizontal", "color", "panel").
        The values indicate the names of variables.
    grid_n : int, optional
        The number of points used to evaluate the main covariate. Defaults to 200.
    groups_n : int, optional
        The number of groups to create when the color or panel variable is numeric.
        Defaults to 5.

    Returns
    -------
    pandas.DataFrame
        The data for the Conditional Adjusted Predictions plot. Columns are cast to the
        dtypes of the matching columns in ``model.data``.

    Raises
    ------
    ValueError
        When either the main or the group covariates are not numeric or categoric.
    """
    observed = model.data
    main = covariates.get("horizontal")
    group = covariates.get("color", None)
    panel = covariates.get("panel", None)

    # Grid of values for the main variable
    main_values = make_main_values(observed[main], grid_n)
    main_n = len(main_values)

    # Grouping variables that get crossed with the main one, ordered from the fastest to the
    # slowest varying. The same logic applies to both the color and the panel variables.
    crossed = []
    if group:
        crossed.append((group, make_group_values(observed[group], groups_n)))
    if panel:
        panel_values = make_group_values(observed[panel], groups_n)
        # When color and panel refer to the same variable it is crossed only once
        if panel != group:
            crossed.append((panel, panel_values))

    data_dict = {main: main_values}
    block_n = main_n
    for name, levels in crossed:
        levels_n = len(levels)
        # Replicate what has been built so far once per level of the new variable...
        data_dict = {key: np.tile(values, levels_n) for key, values in data_dict.items()}
        # ...and keep each level constant within one replicate
        data_dict[name] = np.repeat(levels, block_n)
        block_n *= levels_n

    # Collect the terms of _all_ the distributional components, not just the response
    terms = {}
    for dist_component in model.distributional_components.values():
        design = dist_component.design
        if design.common:
            terms.update(design.common.terms)
        if design.group:
            terms.update(design.group.terms)

    # Fill in a default value for every variable that is not one of the covariates
    for term in terms.values():
        if not hasattr(term, "components"):
            continue
        for component in term.components:
            # If the component is a function call, use the argument names
            if isinstance(component, Call):
                names = [arg.name for arg in component.call.args]
            else:
                names = [component.name]
            for name in names:
                if name in data_dict:
                    continue
                if component.kind == "numeric":
                    # Numeric predictors are fixed at their mean
                    data_dict[name] = np.mean(observed[name])
                elif component.kind == "categoric":
                    # Categoric predictors are fixed at their most frequent level
                    data_dict[name] = mode(observed[name])

    cap_data = pd.DataFrame(data_dict)

    # Make sure the new columns have the same types as the original columns
    for column in cap_data.columns:
        cap_data[column] = cap_data[column].astype(observed[column].dtype)
    return cap_data
def plot_cap(
    model,
    idata,
    covariates,
    target="mean",
    pps=False,
    use_hdi=True,
    hdi_prob=None,
    transforms=None,
    legend=True,
    ax=None,
    fig_kwargs=None,
):
    """Plot Conditional Adjusted Predictions

    Parameters
    ----------
    model : bambi.Model
        The model for which we want to plot the predictions.
    idata : arviz.InferenceData
        The InferenceData object that contains the samples from the posterior distribution of
        the model.
    covariates : list or dict
        A sequence of between one and three names of variables or a dict of length between one
        and three.
        If a sequence, the first variable is taken as the main variable,
        mapped to the horizontal axis. If present, the second name is a coloring/grouping
        variable, and the third is mapped to different plot panels.
        If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values
        are the names of the variables.
    target : str
        Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only
        works when pps is False as the target may not be available in the posterior predictive
        distribution.
    pps: bool, optional
        Whether to plot the posterior predictive samples. Defaults to ``False``.
    use_hdi : bool, optional
        Whether to compute the highest density interval (defaults to True) or the quantiles.
    hdi_prob : float, optional
        The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
    transforms : dict, optional
        Transformations that are applied to each of the variables being plotted. The keys are the
        name of the variables, and the values are functions to be applied. Defaults to ``None``.
    legend : bool, optional
        Whether to automatically include a legend in the plot. Defaults to ``True``.
    ax : matplotlib.axes._subplots.AxesSubplot, optional
        A matplotlib axes object or a sequence of them. If None, this function instantiates a
        new axes object. Defaults to ``None``.
    fig_kwargs : dict, optional
        Keyword arguments passed as ``backend_kwargs`` to the function that creates the grid of
        axes. Only used when ``ax`` is ``None``. Defaults to ``None``.

    Returns
    -------
    matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot
        A tuple with the figure and the axes.

    Raises
    ------
    ValueError
        When ``hdi_prob`` is not within 0 and 1.
        When the main covariate is not numeric or categoric.
    AssertionError
        When ``covariates`` is a dict without the "horizontal" key or with unknown keys, or when
        the number of covariates is not between one and three.
    """
    covariate_kinds = ("horizontal", "color", "panel")
    if not isinstance(covariates, dict):
        covariates = listify(covariates)
        covariates = dict(zip(covariate_kinds, covariates))
    else:
        assert covariate_kinds[0] in covariates, "'covariates' must have a 'horizontal' key."
        assert set(covariates).issubset(
            set(covariate_kinds)
        ), f"The keys of 'covariates' must be taken from {covariate_kinds}."
    assert 1 <= len(covariates) <= 3, "Between one and three covariates must be passed."

    if hdi_prob is None:
        hdi_prob = az.rcParams["stats.hdi_prob"]

    if not 0 < hdi_prob < 1:
        raise ValueError(f"'hdi_prob' must be greater than 0 and smaller than 1. It is {hdi_prob}.")

    cap_data = create_cap_data(model, covariates)

    if transforms is None:
        transforms = {}

    response_name = get_aliased_name(model.response_component.response_term)
    response_transform = transforms.get(response_name, identity)

    # The posterior predictive samples are stored under the response name, while the samples of
    # a model parameter are stored in the posterior under "<response>_<target>".
    if pps:
        idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
        var_name = response_name
        y_hat = response_transform(idata.posterior_predictive[var_name])
    else:
        idata = model.predict(idata, data=cap_data, inplace=False)
        var_name = f"{response_name}_{target}"
        y_hat = response_transform(idata.posterior[var_name])
    y_hat_mean = y_hat.mean(("chain", "draw"))

    if use_hdi:
        y_hat_bounds = az.hdi(y_hat, hdi_prob)[var_name].T
    else:
        lower_bound = round((1 - hdi_prob) / 2, 4)
        upper_bound = 1 - lower_bound
        y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))

    if ax is None:
        fig_kwargs = {} if fig_kwargs is None else fig_kwargs
        panel = covariates.get("panel", None)
        # One panel per level of the panel covariate, a single panel otherwise
        panels_n = len(np.unique(cap_data[panel])) if panel else 1
        rows, cols = default_grid(panels_n)
        fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
        axes = np.atleast_1d(axes)
    else:
        axes = np.atleast_1d(ax)
        # Handle both one and two dimensional arrays of axes
        if isinstance(axes[0], np.ndarray):
            fig = axes[0][0].get_figure()
        else:
            fig = axes[0].get_figure()

    main = covariates.get("horizontal")
    main_column = cap_data[main]
    if is_numeric_dtype(main_column):
        axes = _plot_cap_numeric(
            covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes
        )
    elif isinstance(main_column.dtype, pd.CategoricalDtype) or is_string_dtype(main_column):
        axes = _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes)
    else:
        raise ValueError("Main covariate must be numeric or categoric.")

    # With 'pps=True' the plotted quantity is always the response, whatever 'target' is
    ylabel = response_name if pps or target == "mean" else target
    for axis in axes.ravel():
        axis.set(xlabel=main, ylabel=ylabel)

    return fig, axes
def _plot_cap_numeric(covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes):
    """Draw the predictions when the main covariate is numeric.

    The mean is drawn as a line and the interval as a shaded band. Levels of the "color"
    covariate are mapped to the colors "C0", "C1", ... and levels of the "panel" covariate
    are mapped to successive axes.

    Parameters
    ----------
    covariates : dict
        Maps "horizontal", and optionally "color" and "panel", to variable names.
    cap_data : pandas.DataFrame
        The data the predictions were computed on.
    y_hat_mean : array-like
        The mean prediction for each row of ``cap_data``.
    y_hat_bounds : array-like
        The interval bounds. ``y_hat_bounds[0]`` and ``y_hat_bounds[1]`` delimit the band.
    transforms : dict
        Functions keyed by variable name. Only the one for the main covariate is used here.
    legend : bool
        Whether to add a legend when there is a "color" covariate.
    axes : numpy.ndarray
        The matplotlib axes to draw on.

    Returns
    -------
    numpy.ndarray
        The same ``axes`` that were passed.
    """
    main = covariates.get("horizontal")
    transform_main = transforms.get(main, identity)
    has_color = "color" in covariates
    has_panel = "panel" in covariates

    def draw(axis, mask, **style):
        # Line plus band for the rows selected by the boolean 'mask'
        x_values = transform_main(cap_data.loc[mask, main])
        axis.plot(x_values, y_hat_mean[mask], solid_capstyle="butt", **style)
        axis.fill_between(
            x_values, y_hat_bounds[0][mask], y_hat_bounds[1][mask], alpha=0.4, **style
        )

    if len(covariates) == 1:
        axis = axes[0]
        x_values = transform_main(cap_data[main])
        axis.plot(x_values, y_hat_mean, solid_capstyle="butt")
        axis.fill_between(x_values, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
    elif has_color and not has_panel:
        axis = axes[0]
        color = covariates.get("color")
        colors = get_unique_levels(cap_data[color])
        for i, level in enumerate(colors):
            draw(axis, (cap_data[color] == level).to_numpy(), color=f"C{i}")
    elif has_panel and not has_color:
        panel = covariates.get("panel")
        panels = get_unique_levels(cap_data[panel])
        for axis, panel_level in zip(axes.ravel(), panels):
            draw(axis, (cap_data[panel] == panel_level).to_numpy())
            axis.set(title=f"{panel} = {panel_level}")
    elif has_color and has_panel:
        color = covariates.get("color")
        panel = covariates.get("panel")
        colors = get_unique_levels(cap_data[color])
        panels = get_unique_levels(cap_data[panel])
        if color == panel:
            # Same variable for both: one level, with its own color, per panel
            for i, (axis, panel_level) in enumerate(zip(axes.ravel(), panels)):
                draw(axis, (cap_data[panel] == panel_level).to_numpy(), color=f"C{i}")
                axis.set(title=f"{panel} = {panel_level}")
        else:
            for axis, panel_level in zip(axes.ravel(), panels):
                in_panel = cap_data[panel] == panel_level
                for i, level in enumerate(colors):
                    mask = (in_panel & (cap_data[color] == level)).to_numpy()
                    draw(axis, mask, color=f"C{i}")
                axis.set(title=f"{panel} = {panel_level}")

    if has_color and legend:
        # Each legend entry combines a line and a patch, mimicking the line plus the band
        handles = [
            (
                Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
                Patch(color=f"C{i}", alpha=0.4, lw=1),
            )
            for i, _ in enumerate(colors)
        ]
        labels = tuple(colors)
        for axis in axes.ravel():
            axis.legend(
                handles, labels, title=color, handlelength=1.3, handleheight=1, loc="best"
            )
    return axes
def _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes):
    """Draw the predictions when the main covariate is categoric.

    The mean is drawn as a point and the interval as a vertical segment, at integer
    positions along the horizontal axis (one per level of the main covariate). Levels of the
    "color" covariate are shifted horizontally and mapped to the colors "C0", "C1", ...
    Levels of the "panel" covariate are mapped to successive axes.

    Parameters
    ----------
    covariates : dict
        Maps "horizontal", and optionally "color" and "panel", to variable names.
    cap_data : pandas.DataFrame
        The data the predictions were computed on.
    y_hat_mean : array-like
        The mean prediction for each row of ``cap_data``.
    y_hat_bounds : array-like
        The interval bounds. ``y_hat_bounds[0]`` and ``y_hat_bounds[1]`` delimit the segment.
    legend : bool
        Whether to add a legend when there is a "color" covariate.
    axes : numpy.ndarray
        The matplotlib axes to draw on.

    Returns
    -------
    numpy.ndarray
        The same ``axes`` that were passed.
    """
    main = covariates.get("horizontal")
    main_levels = get_unique_levels(cap_data[main])
    main_levels_n = len(main_levels)
    idxs_main = np.arange(main_levels_n)
    has_color = "color" in covariates
    has_panel = "panel" in covariates

    if has_color:
        color = covariates.get("color")
        colors = get_unique_levels(cap_data[color])
        colors_n = len(colors)
        # Horizontal shift applied to each color so the groups don't overlap
        offset_bounds = get_group_offset(colors_n)
        colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)

    if has_panel:
        panel = covariates.get("panel")
        panels = get_unique_levels(cap_data[panel])

    def draw(axis, positions, mask, **style):
        # Points plus segments for the rows selected by the boolean 'mask'
        axis.scatter(positions, y_hat_mean[mask], **style)
        axis.vlines(positions, y_hat_bounds[0][mask], y_hat_bounds[1][mask], **style)

    if len(covariates) == 1:
        axis = axes[0]
        axis.scatter(idxs_main, y_hat_mean)
        axis.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1])
    elif has_color and not has_panel:
        axis = axes[0]
        for i, level in enumerate(colors):
            mask = (cap_data[color] == level).to_numpy()
            draw(axis, idxs_main + colors_offset[i], mask, color=f"C{i}")
    elif has_panel and not has_color:
        for axis, panel_level in zip(axes.ravel(), panels):
            draw(axis, idxs_main, (cap_data[panel] == panel_level).to_numpy())
            axis.set(title=f"{panel} = {panel_level}")
    elif has_color and has_panel:
        if color == panel:
            # Same variable for both: one level, with its own color, per panel
            for i, (axis, panel_level) in enumerate(zip(axes.ravel(), panels)):
                mask = (cap_data[panel] == panel_level).to_numpy()
                draw(axis, idxs_main + colors_offset[i], mask, color=f"C{i}")
                axis.set(title=f"{panel} = {panel_level}")
        else:
            for axis, panel_level in zip(axes.ravel(), panels):
                in_panel = cap_data[panel] == panel_level
                for i, level in enumerate(colors):
                    mask = (in_panel & (cap_data[color] == level)).to_numpy()
                    draw(axis, idxs_main + colors_offset[i], mask, color=f"C{i}")
                axis.set(title=f"{panel} = {panel_level}")

    if has_color and legend:
        handles = [
            Line2D([], [], c=f"C{i}", marker="o", label=level) for i, level in enumerate(colors)
        ]
        for axis in axes.ravel():
            axis.legend(handles=handles, title=color, loc="best")

    # Label the integer positions with the levels of the main covariate
    for axis in axes.ravel():
        axis.set_xticks(idxs_main)
        axis.set_xticklabels(main_levels)
    return axes
def identity(x):
    """Return ``x`` unchanged. Used as the default when no transformation is requested."""
    return x
def make_main_values(x, grid_n):
    """Compute the values at which the main covariate is evaluated.

    Parameters
    ----------
    x : pandas.Series
        The observed values of the main covariate.
    grid_n : int
        The number of points in the grid when ``x`` is numeric.

    Returns
    -------
    numpy.ndarray
        ``grid_n`` equally spaced values between the minimum and the maximum of ``x`` when it
        is numeric, or the sorted unique values of ``x`` when it is a string or categorical.

    Raises
    ------
    ValueError
        When ``x`` is neither numeric, string, nor categorical.
    """
    if is_numeric_dtype(x):
        return np.linspace(np.min(x), np.max(x), grid_n)
    # 'is_categorical_dtype' is deprecated in pandas; check the dtype instance instead
    if is_string_dtype(x) or isinstance(getattr(x, "dtype", None), pd.CategoricalDtype):
        return np.unique(x)
    raise ValueError("Main covariate must be numeric or categoric.")
def make_group_values(x, groups_n):
    """Compute the values used for a grouping (color or panel) covariate.

    Parameters
    ----------
    x : pandas.Series
        The observed values of the grouping covariate.
    groups_n : int
        The number of groups to create when ``x`` is numeric.

    Returns
    -------
    numpy.ndarray
        The sorted unique values of ``x`` when it is a string or categorical, or the quantiles
        of ``x`` at ``groups_n`` equally spaced probabilities between 0 and 1 when it is
        numeric.

    Raises
    ------
    ValueError
        When ``x`` is neither numeric, string, nor categorical.
    """
    # 'is_categorical_dtype' is deprecated in pandas; check the dtype instance instead
    if is_string_dtype(x) or isinstance(getattr(x, "dtype", None), pd.CategoricalDtype):
        return np.unique(x)
    if is_numeric_dtype(x):
        return np.quantile(x, np.linspace(0, 1, groups_n))
    raise ValueError("Group covariate must be numeric or categoric.")