Plot Conditional Adjusted Predictions

This notebook shows how to use the plot_cap function and demonstrates its capabilities. plot_cap is part of Bambi’s plots sub-package, a set of tools for interpreting complex regression models that is inspired by the R package marginaleffects.

Interpreting Generalized Linear Models

The purpose of the generalized linear model (GLM) is to unify the approaches needed to analyze data for which either (1) the assumption of a linear relation between \(x\) and \(y\) or (2) the assumption of normal variation is not appropriate. GLMs are typically specified in three stages:

  1. the linear predictor \(\eta = X\beta\), where \(X\) is an \(n \times p\) matrix of explanatory variables;

  2. the link function \(g(\cdot)\) that relates the linear predictor to the mean of the outcome variable, \(\mu = g^{-1}(\eta) = g^{-1}(X\beta)\);

  3. the random component specifying the distribution of the outcome variable \(y\) with mean \(\mathbb{E}(y|X) = \mu\).

Based on these three specifications, the mean of the distribution of \(y\), given \(X\), is determined by \(X\beta\): \(\mathbb{E}(y|X) = g^{-1}(X\beta)\).
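As a minimal illustration of the three stages, the NumPy sketch below builds a tiny Poisson GLM with a log link. The design matrix, the coefficients in beta, and the choice of a Poisson random component are all hypothetical, chosen only to make each stage concrete.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical design: an intercept column plus one predictor (n = 5, p = 2)
X = np.column_stack([np.ones(5), np.linspace(-1.0, 1.0, 5)])
beta = np.array([0.5, 1.2])  # hypothetical coefficients

# 1. Linear predictor: eta = X beta (any real value)
eta = X @ beta

# 2. Link function g = log, so the mean is mu = g^{-1}(eta) = exp(eta) > 0
mu = np.exp(eta)

# 3. Random component: y | X ~ Poisson(mu), so E(y | X) = mu
y = rng.poisson(mu)

# Monte Carlo check that the average of many draws approaches mu = g^{-1}(X beta)
mc_mean = rng.poisson(mu, size=(100_000, 5)).mean(axis=0)
print(np.round(mu, 3), np.round(mc_mean, 3))
```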

GLMs are a broad family of models where the output \(y\) is typically assumed to follow an exponential family distribution, e.g., Binomial, Poisson, Gamma, Exponential, and Normal. Through its inverse, the link function maps the linear predictor \(X\beta\), which can take any real value, onto the (possibly restricted) scale of a parameter like \(\mu\); for example, a Poisson mean must be positive and a Binomial probability must lie between 0 and 1. Commonly used link functions include the logit and log links, which are the canonical links for the Binomial and Poisson distributions, respectively. This brief introduction to GLMs is not meant to be exhaustive; another good starting point is the Bambi Basic Building Blocks example.
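The sketch below writes out the logit and log links and their inverses in NumPy (the function names are illustrative, not Bambi API) and checks that each inverse maps an arbitrary real-valued linear predictor onto the restricted range of the mean, and that applying the link afterwards recovers the linear predictor.

```python
import numpy as np

# Logit link (canonical for the Binomial/Bernoulli): probability in (0, 1) -> real line
def logit(p):
    return np.log(p / (1 - p))

def inv_logit(eta):
    return 1 / (1 + np.exp(-eta))

# Log link (canonical for the Poisson): positive mean -> real line
def log_link(mu):
    return np.log(mu)

def inv_log_link(eta):
    return np.exp(eta)

eta = np.array([-3.0, 0.0, 2.5])  # any real-valued linear predictor
p = inv_logit(eta)                # lands in (0, 1)
rate = inv_log_link(eta)          # lands in (0, inf)
print(p, rate)
```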

Due to the link function, there are typically three quantities of interest to interpret in a GLM:

  1. the linear predictor \(\eta\);

  2. the mean \(\mu = g^{-1}(\eta)\);

  3. the response variable \(Y \sim \mathcal{D}(\mu, \theta)\), where \(\mu\) is the mean parameter and \(\theta\) is (possibly) a vector that contains all the other “nuisance” parameters of the distribution.

As modelers, we are usually more interested in interpreting (2) and (3). The linear predictor \(\eta\) is not always on the same scale as the response variable, which can make it difficult to interpret; the response scale is usually more interpretable. Additionally, modelers often want to analyze how a model parameter varies across a range of explanatory variable values. To support such an analysis, Bambi has taken inspiration from the R package marginaleffects and implemented a plot_cap function that plots conditional adjusted predictions to aid in the interpretation of GLMs. Below, we briefly discuss what conditional adjusted predictions are, how they are computed, and how to use the plot_cap function.

Conditional Adjusted Predictions

Adjusted predictions refer to the outcome predicted by a fitted model on a specified scale for a given combination of values of the predictor variables, such as their observed values, their means, or some user-specified grid of values. Predictions can be made on either the link scale (the scale of the linear predictor) or the response scale (the scale of the outcome). In normal linear regression, the link and response scales are identical, so the adjusted prediction is the mean value of the response variable at the given values of the predictor variables. In logistic regression, on the other hand, the two scales differ: an adjusted prediction on the link scale is the log-odds of a successful response given values of the predictor variables, whereas an adjusted prediction on the response scale is the probability that the response variable equals 1. The conditional part of conditional adjusted predictions refers to the specific predictor(s), and their values, that we condition on when plotting predictions.
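To make the link/response distinction concrete, here is a sketch for a logistic regression whose coefficients b0 and b1 are made up rather than estimated: the same user-specified grid of predictor values yields log-odds on the link scale and probabilities on the response scale.

```python
import numpy as np

# Hypothetical fitted logistic regression: logit(P(y = 1)) = b0 + b1 * x
b0, b1 = -1.0, 0.8

# A user-specified grid of predictor values
x_grid = np.array([-2.0, 0.0, 1.25, 3.0])

# Adjusted predictions on the link scale: log-odds of y = 1
log_odds = b0 + b1 * x_grid

# Adjusted predictions on the response scale: P(y = 1) via the inverse logit
prob = 1 / (1 + np.exp(-log_odds))

for x, lo, pr in zip(x_grid, log_odds, prob):
    print(f"x = {x:5.2f}  log-odds = {lo:6.3f}  P(y = 1) = {pr:.3f}")
```

With an identity link (normal linear regression), the two columns would coincide.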

Computing Adjusted Predictions

The objective of plotting conditional adjusted predictions is to visualize how a parameter of the (conditional) response distribution varies as a function of (some) interpolated explanatory variables. This is done by building a reference grid: the explanatory variable(s) of interest vary over a range of values, while all other explanatory variables are held constant at some specified value. The grid may or may not correspond to actual observations in the dataset used to fit the model. By default, plot_cap uses 200 equally spaced values between the minimum and maximum of the specified explanatory variable.
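A minimal sketch of how such a reference grid could be built is shown below. The helper make_reference_grid and the toy data are hypothetical, and holding numeric covariates at their mean and categorical covariates at their most frequent level is an assumed convention for illustration, not a description of Bambi's internals.

```python
import numpy as np
from collections import Counter

def make_reference_grid(data, main, num=200):
    """Hypothetical helper: vary `main` over `num` equally spaced values and hold
    every other covariate fixed (assumed: numeric at its mean, categorical at
    its most frequent level)."""
    grid = {main: np.linspace(np.min(data[main]), np.max(data[main]), num)}
    for name, values in data.items():
        if name == main:
            continue
        values = np.asarray(values)
        if np.issubdtype(values.dtype, np.number):
            fixed = values.mean()
        else:
            fixed = Counter(values.tolist()).most_common(1)[0][0]
        grid[name] = np.full(num, fixed)
    return grid

# Toy data loosely shaped like the mtcars columns used later (values are made up)
toy = {
    "hp": np.array([110.0, 93.0, 175.0, 245.0, 62.0]),
    "wt": np.array([2.62, 2.32, 3.44, 3.57, 1.84]),
    "cyl": np.array(["medium", "low", "high", "high", "high"]),
}
grid = make_reference_grid(toy, "hp")
print(grid["hp"][:3], grid["wt"][0], grid["cyl"][0])
```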

The plot_cap function then uses the fitted model to compute the predicted value of the model parameter at each point of the reference grid, and plots these predictions as a function of the chosen explanatory variable(s).
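The prediction step can be sketched with NumPy as follows. Here the posterior draws of b0 and b1 are simulated for a hypothetical Gaussian model (in practice they would come from the fitted model's posterior), and the band is a central 94% quantile interval rather than the HDI that plot_cap uses by default.

```python
import numpy as np

rng = np.random.default_rng(1234)

# Hypothetical posterior draws for a Gaussian model with mean mu = b0 + b1 * hp
n_draws = 4000
b0 = rng.normal(30.0, 1.0, n_draws)
b1 = rng.normal(-0.07, 0.005, n_draws)

hp_grid = np.linspace(50, 335, 200)  # hypothetical reference grid for the main covariate

# One predicted mean per (draw, grid value): shape (n_draws, 200)
mu_draws = b0[:, None] + b1[:, None] * hp_grid[None, :]

# Summarize across draws at each grid value: posterior mean and a central 94% interval
mu_mean = mu_draws.mean(axis=0)
lower, upper = np.quantile(mu_draws, [0.03, 0.97], axis=0)
print(mu_mean[:3], lower[:3], upper[:3])
```

Plotting mu_mean against hp_grid, with a band between lower and upper, gives a plot of the same kind as the ones plot_cap produces.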

[9]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from bambi.plots import plot_cap

%load_ext autoreload
%autoreload 2

Gaussian Linear Model

For the first demonstration, we will use a Gaussian linear regression model with the mtcars dataset to better understand the plot_cap function and its arguments. The mtcars dataset was extracted from the 1974 Motor Trend US magazine, and comprises fuel consumption and 10 aspects of automobile design and performance for 32 automobiles (1973–74 models). The following is a brief description of the variables in the dataset:

  • mpg: Miles/(US) gallon

  • cyl: Number of cylinders

  • disp: Displacement (cu.in.)

  • hp: Gross horsepower

  • drat: Rear axle ratio

  • wt: Weight (1000 lbs)

  • qsec: 1/4 mile time

  • vs: Engine (0 = V-shaped, 1 = straight)

  • am: Transmission (0 = automatic, 1 = manual)

  • gear: Number of forward gears

[46]:
# Load data
data = bmb.load_data('mtcars')
data["cyl"] = data["cyl"].replace({4: "low", 6: "medium", 8: "high"})
data["gear"] = data["gear"].replace({3: "A", 4: "B", 5: "C"})
data["cyl"] = pd.Categorical(data["cyl"], categories=["low", "medium", "high"], ordered=True)

# Define and fit the Bambi model
model = bmb.Model("mpg ~ 0 + hp * wt + cyl + gear", data)
idata = model.fit(draws=1000, target_accept=0.95, random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [response_sigma, hp, wt, hp:wt, cyl, gear]
100.00% [8000/8000 00:17<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 18 seconds.

Printing the Bambi model object (by evaluating model in a cell) displays the model components. For this Gaussian linear model, it shows an identity link function, which leaves the linear predictor untransformed as the mean of the outcome variable, and a Gaussian likelihood.

Now that we have fitted the model, we can visualize how a model parameter varies as a function of (some) interpolated covariate. For this example, we will visualize how the mean response mpg varies as a function of the covariate hp.

The Bambi model, the ArviZ InferenceData object (containing the posterior samples and the data used to fit the model), and the covariate(s) of interest, here only hp, are passed to the plot_cap function. Covariates can be given as a single name or as a list or dictionary of names. The plot_cap function then computes the conditional adjusted predictions for the covariates using the method described above. The plot_cap function returns a matplotlib figure object that can be further customized.

[49]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "hp", ax=ax);
[Figure: mean mpg as a function of hp]

The plot above shows that as hp increases, the mean mpg decreases. As stated above, this insight was obtained by creating the reference grid and then using the fitted model to compute the predicted value of the model parameter, here the mean of mpg, at each value of the reference grid.

By default, plot_cap summarizes the uncertainty of the conditional adjusted predictions with the 94% highest density interval (HDI) of the posterior distribution. The HDI is the narrowest interval that contains a specified probability of the posterior distribution; it plays a role analogous to the frequentist confidence interval.
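The idea behind the HDI can be sketched in a few lines of NumPy: sort the draws and pick the narrowest window that covers the requested fraction of them. The hdi function below is a hypothetical helper for illustration; in practice ArviZ provides az.hdi.

```python
import numpy as np

def hdi(samples, prob=0.94):
    """Hypothetical helper: narrowest interval containing `prob` of the samples."""
    x = np.sort(np.asarray(samples))
    n = len(x)
    k = int(np.floor(prob * n))   # each candidate window spans k gaps (k + 1 points)
    widths = x[k:] - x[: n - k]   # width of every window of k + 1 consecutive sorted points
    i = int(np.argmin(widths))    # the narrowest window is the HDI
    return x[i], x[i + k]

# Usage on a right-skewed distribution, compared with the equal-tailed interval
rng = np.random.default_rng(0)
draws = rng.exponential(scale=1.0, size=20_000)
lo, hi = hdi(draws, 0.94)
q_lo, q_hi = np.quantile(draws, [0.03, 0.97])
print(f"94% HDI: [{lo:.3f}, {hi:.3f}]  equal-tailed: [{q_lo:.3f}, {q_hi:.3f}]")
```

For a skewed distribution like this one, the HDI sits closer to zero and is narrower than the equal-tailed interval with the same probability; for a symmetric, unimodal posterior the two nearly coincide.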

By default, plot_cap uses the posterior distribution to visualize a mean outcome parameter. However, the posterior predictive distribution of the response variable can also be plotted by specifying pps=True, where pps stands for posterior predictive samples.

[37]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "hp", pps=True, ax=ax);
[Figure: posterior predictive mpg as a function of hp]

Here, the uncertainty in the conditional adjusted predictions is much larger than when pps=False. This is because the posterior predictive distribution accounts for both the uncertainty in the model parameters and the observation-level variability of the data (in this Gaussian model, the noise governed by sigma), whereas the posterior distribution of the mean only accounts for the uncertainty in the model parameters.
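The difference in width can be reproduced with simulated draws. In the sketch below, mu and sigma are made-up posterior draws at a single grid value; adding Normal observation noise to each draw of mu yields posterior predictive draws whose interval is much wider than the interval for the mean.

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws = 4000

# Hypothetical posterior draws at one grid value
mu = rng.normal(20.0, 0.5, n_draws)             # uncertainty about the mean only
sigma = np.abs(rng.normal(2.5, 0.2, n_draws))   # residual standard deviation

# Posterior predictive draws: one noisy observation per (mu, sigma) draw
y_pred = rng.normal(mu, sigma)

mean_interval = np.quantile(mu, [0.03, 0.97])
pred_interval = np.quantile(y_pred, [0.03, 0.97])
print("mean:", np.round(mean_interval, 2), "predictive:", np.round(pred_interval, 2))
```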

plot_cap allows up to three covariates to be plotted simultaneously: the first element in the list is the main (x-axis) covariate, the second is the group (hue / color), and the third is the facet (panel). When plotting more than one covariate, it can aid interpretation to pass specific group and panel arguments. The subplot_kwargs argument allows this: it takes a dictionary whose keys are {"main": ..., "group": ..., "panel": ...} and whose values are the names of the covariates to be plotted. In the example below, we pass the two covariates hp and wt and specify subplot_kwargs={"main": "hp", "group": "wt", "panel": "wt"}, so that wt determines both the color and the panel.

[17]:
plot_cap(
    model=model,
    idata=idata,
    covariates=["hp", "wt"],
    pps=False,
    legend=False,
    subplot_kwargs={"main": "hp", "group": "wt", "panel": "wt"},
    fig_kwargs={"figsize": (20, 8), "sharey": True}
)
plt.tight_layout();
[Figure: mean mpg as a function of hp, grouped and paneled by wt]

Categorical covariates can also be plotted. Below, we plot the mean mpg as a function of the two categorical covariates gear and cyl. The plot_cap function automatically plots the conditional adjusted predictions for each level of a categorical covariate. Furthermore, when a list of covariates is passed to plot_cap, it is converted into a dictionary whose keys are taken from (“horizontal”, “color”, “panel”) and whose values are the names of the variables. By default, the first element of the list is the “horizontal” covariate, the second is the “color” covariate, and the third is mapped to different plot panels.

[40]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, ["gear", "cyl"], ax=ax);
[Figure: mean mpg for each level of gear, colored by cyl]

Negative Binomial Model

Let's move on to a model whose response distribution is another member of the exponential family and that uses a non-identity link function. For this, we will implement the negative binomial model from the student absences example. School administrators study the attendance behavior of high school juniors at two schools. Predictors of the number of days of absence include the type of program in which the student is enrolled and a score on a standardized math test. We have attendance data on 314 high school juniors. The variables of interest in the dataset are the following:

  • daysabs: The number of days of absence. It is our response variable.

  • prog: The type of program. Can be one of ‘General’, ‘Academic’, or ‘Vocational’.

  • math: Score in a standardized math test.

[4]:
# Load data, define and fit Bambi model
data = pd.read_stata("https://stats.idre.ucla.edu/stat/stata/dae/nb_data.dta")
data["prog"] = data["prog"].map({1: "General", 2: "Academic", 3: "Vocational"})

model_interaction = bmb.Model(
    "daysabs ~ 0 + prog + scale(math) + prog:scale(math)",
    data,
    family="negativebinomial"
)
idata_interaction = model_interaction.fit(
    draws=1000, target_accept=0.95, random_seed=1234, chains=4
)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [daysabs_alpha, prog, scale(math), prog:scale(math)]
100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

This model uses a log link function and a negative binomial likelihood. Note that it also contains an interaction term, prog:scale(math).

[41]:
model_interaction
[41]:
       Formula: daysabs ~ 0 + prog + scale(math) + prog:scale(math)
        Family: negativebinomial
          Link: mu = log
  Observations: 314
        Priors:
    target = mu
        Common-level effects
            prog ~ Normal(mu: [0. 0. 0.], sigma: [5.0102 7.4983 5.2746])
            scale(math) ~ Normal(mu: 0.0, sigma: 2.5)
            prog:scale(math) ~ Normal(mu: [0. 0.], sigma: [6.1735 4.847 ])

        Auxiliary parameters
            alpha ~ HalfCauchy(beta: 1.0)
------
* To see a plot of the priors call the .plot_priors() method.
* To see a summary or plot of the posterior pass the object returned by .fit() to az.summary() or az.plot_trace()
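Because the link is log, effects act multiplicatively on the mean. The sketch below uses made-up coefficients (not the fitted values) laid out like the model above: one intercept per prog level, a scale(math) slope, and prog:scale(math) adjustments for two of the three levels. Taking Academic as the reference level is an assumption for illustration. The sketch shows that the multiplicative change in mean daysabs per standard deviation of math differs by program.

```python
import numpy as np

# Hypothetical log-scale coefficients (not the posterior estimates)
intercept = {"Academic": 1.5, "General": 2.0, "Vocational": 1.0}   # prog terms
slope_math = -0.2                                                # scale(math), reference level
slope_int = {"Academic": 0.0, "General": -0.1, "Vocational": -0.3} # prog:scale(math); 0 for reference

def expected_days(prog, z):
    """Mean daysabs under a log link: mu = exp(eta), z = standardized math score."""
    eta = intercept[prog] + (slope_math + slope_int[prog]) * z
    return np.exp(eta)

# A one-SD increase in math multiplies mu by exp(total slope), which depends on prog
for prog in intercept:
    ratio = expected_days(prog, 1.0) / expected_days(prog, 0.0)
    print(f"{prog:10s} multiplicative change per SD of math: {ratio:.3f}")
```

On the response scale, the same ratio translates into a larger absolute drop in days for programs with a higher baseline, which is why the curves plotted below are not parallel.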
[42]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(
    model_interaction,
    idata_interaction,
    "math",
    ax=ax,
    pps=False
);
[Figure: mean daysabs as a function of math]

The plot above shows that as math increases, the mean daysabs decreases. However, as the model contains an interaction term, the effect of math on daysabs depends on the value of prog. Therefore, we will use plot_cap to plot the conditional adjusted predictions for each level of prog.

[5]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(
    model_interaction,
    idata_interaction,
    ["math", "prog"],
    ax=ax,
    pps=False
);
[Figure: mean daysabs as a function of math, colored by prog]

Passing specific subplot_kwargs can make the plot more interpretable, especially when the credible intervals of the posterior predictive distribution overlap across groups.

[47]:
plot_cap(
    model_interaction,
    idata_interaction,
    covariates=["math", "prog"],
    pps=True,
    subplot_kwargs={"main": "math", "group": "prog", "panel": "prog"},
    legend=False,
    fig_kwargs={"figsize": (16, 5), "sharey": True}
);
[Figure: posterior predictive daysabs as a function of math, one panel per prog]

Logistic Regression

To further demonstrate the plot_cap function, we will implement a logistic regression model. This example is taken from the marginaleffects plot_predictions documentation. The Internet Movie Database, http://imdb.com/, is a website devoted to collecting movie data supplied by studios and fans. It claims to be the biggest movie database on the web and is run by Amazon. Movies were selected for inclusion in this dataset if they had a known length and had been rated by at least one IMDb user. The dataset is described as having 28,819 rows and 24 columns, although the CSV loaded below is larger (the model summary reports 58,662 observations after filtering). The variables of interest are the following:

  • title: Title of the movie.

  • year: Year of release.

  • budget: Total budget (if known) in US dollars.

  • length: Length in minutes.

  • rating: Average IMDb user rating.

  • votes: Number of IMDb users who rated this movie.

  • r1 to r10: Multiplying by ten gives the percentile (to the nearest 10%) of users who rated this movie 1, 2, ..., 10, respectively.

  • mpaa: MPAA rating.

  • Action, Animation, Comedy, Drama, Documentary, Romance, Short: Binary variables indicating whether the movie was classified as belonging to that genre.

[70]:
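data = pd.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2movies/movies.csv")

# Assign a single style per movie; later assignments take precedence,
# so a movie tagged as both Action and Drama ends up labeled "Drama"
data["style"] = "Other"
data.loc[data["Action"] == 1, "style"] = "Action"
data.loc[data["Comedy"] == 1, "style"] = "Comedy"
data.loc[data["Drama"] == 1, "style"] = "Drama"

# Binary outcome: 1 if the average IMDb rating is at least 8, else 0
data["certified_fresh"] = (data["rating"] >= 8).astype(int)

# Keep movies shorter than 240 minutes
data = data[data["length"] < 240]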

priors = {"style": bmb.Prior("Normal", mu=0, sigma=2)}
model = bmb.Model("certified_fresh ~ 0 + length * style", data=data, priors=priors, family="bernoulli")
idata = model.fit(random_seed=1234, target_accept=0.9, init="adapt_diag")
Modeling the probability that certified_fresh==1
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [length, style, length:style]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 421 seconds.

The logistic regression model uses a logit link function and a Bernoulli likelihood. Therefore, the link scale is the log-odds of a successful response and the response scale is the probability of a successful response.

[73]:
model
[73]:
       Formula: certified_fresh ~ 0 + length * style
        Family: bernoulli
          Link: p = logit
  Observations: 58662
        Priors:
    target = p
        Common-level effects
            length ~ Normal(mu: 0.0, sigma: 0.0708)
            style ~ Normal(mu: 0.0, sigma: 2.0)
            length:style ~ Normal(mu: [0. 0. 0.], sigma: [0.0702 0.0509 0.0611])
------
* To see a plot of the priors call the .plot_priors() method.
* To see a summary or plot of the posterior pass the object returned by .fit() to az.summary() or az.plot_trace()

Again, by default, the plot_cap function plots the mean outcome on the response scale. Therefore, the plot below shows the probability that certified_fresh equals 1 as a function of length.

[72]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "length", ax=ax);
[Figure: probability that certified_fresh equals 1 as a function of length]

Additionally, we can see how the probability of certified_fresh varies as a function of categorical covariates.

[79]:
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "style", ax=ax);
[Figure: probability that certified_fresh equals 1 for each style]

Plotting other model parameters

plot_cap also has a target argument that determines which parameter of the response distribution is plotted as a function of the explanatory variables. This argument is useful in distributional models, i.e., when the response distribution contains parameters for location, scale, and/or shape. The default is target="mean". Passing another parameter to target only works when pps=False: when pps=True, the posterior predictive distribution is plotted, which can only refer to the outcome variable rather than to any parameter of the response distribution. For this example, we will simulate our own dataset.

[51]:
rng = np.random.default_rng(121195)
N = 200
a, b = 0.5, 1.1
x = rng.uniform(-1.5, 1.5, N)
shape = np.exp(0.3 + x * 0.5 + rng.normal(scale=0.1, size=N))
y = rng.gamma(shape, np.exp(a + b * x) / shape, N)
data_gamma = pd.DataFrame({"x": x, "y": y})

formula = bmb.Formula("y ~ x", "alpha ~ x")
model = bmb.Model(formula, data_gamma, family="gamma")
idata = model.fit(random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, x, alpha_Intercept, alpha_x]
100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 25 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
There were 25 divergences after tuning. Increase `target_accept` or reparameterize.
[52]:
model
[52]:
       Formula: y ~ x
                alpha ~ x
        Family: gamma
          Link: mu = inverse
                alpha = log
  Observations: 200
        Priors:
    target = mu
        Common-level effects
            Intercept ~ Normal(mu: 0.0, sigma: 2.5037)
            x ~ Normal(mu: 0.0, sigma: 2.8025)
    target = alpha
        Common-level effects
            alpha_Intercept ~ Normal(mu: 0.0, sigma: 1.0)
            alpha_x ~ Normal(mu: 0.0, sigma: 1.0)
------
* To see a plot of the priors call the .plot_priors() method.
* To see a summary or plot of the posterior pass the object returned by .fit() to az.summary() or az.plot_trace()

The model we defined uses a gamma distribution parameterized by alpha and mu where alpha utilizes a log link and mu goes through an inverse link. Therefore, we can plot either: (1) the mu of the response distribution (which is the default), or (2) alpha of the response distribution as a function of the explanatory variable \(x\).
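To see why alpha is worth plotting separately, the sketch below assumes the gamma distribution is parameterized so that mu is the mean and alpha is the shape (scale mu / alpha), matching how the data above were simulated with rng.gamma(shape, mean / shape). It applies the two inverse links to hypothetical linear predictor values and checks by simulation that mu sets the mean while alpha controls the spread (variance mu**2 / alpha).

```python
import numpy as np

rng = np.random.default_rng(0)

def gamma_mean_var(mu, alpha):
    """Mean and variance of a gamma with shape alpha and scale mu / alpha."""
    return mu, mu**2 / alpha

# Hypothetical linear predictor values at one value of x
eta_mu, eta_alpha = 0.4, 0.8
mu = 1 / eta_mu            # inverse link for mu: eta_mu = 1 / mu
alpha = np.exp(eta_alpha)  # log link for alpha: eta_alpha = log(alpha)

draws = rng.gamma(shape=alpha, scale=mu / alpha, size=200_000)
mean_theory, var_theory = gamma_mean_var(mu, alpha)
print(f"mean {draws.mean():.3f} vs {mean_theory:.3f}; var {draws.var():.3f} vs {var_theory:.3f}")
```

Increasing alpha at a fixed mu leaves the mean unchanged but shrinks the variance, so a plot of alpha against x shows how the dispersion of y changes with x.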

[53]:
# First, the mean of the response (default)
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "x", ax=ax);
[Figure: mean of y as a function of x]

Below, instead of plotting the default target="mean", we set target="alpha" to visualize how the model parameter alpha varies as a function of the predictor x.

[54]:
# Second, another param. of the distribution: alpha
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "x", target='alpha', ax=ax);
[Figure: alpha as a function of x]
[55]:
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sat Jun 24 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.13.2

numpy     : 1.24.2
bambi     : 0.10.0.dev0
pandas    : 2.0.1
matplotlib: 3.7.1
arviz     : 0.15.1

Watermark: 2.3.1