Module aliases imported by init_notebook.py:
--------------------------------------------
from cmdstanpy import CmdStanModel
import bridgestan as bs
import numpy as np
import pandas as pd
import arviz as az
import utils as utils
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
Watermark:
----------
Last updated: 2025-03-17T16:25:12.321376+04:00
Python implementation: CPython
Python version : 3.13.2
IPython version : 9.0.2
Compiler : Clang 16.0.0 (clang-1600.0.26.6)
OS : Darwin
Release : 24.4.0
Machine : arm64
Processor : arm
CPU cores : 8
Architecture: 64bit
numpy : 2.2.4
matplotlib: 3.10.1
bridgestan: 2.6.1
cmdstanpy : 1.2.5
pandas : 2.2.3
arviz : 0.21.0
re : 2.2.1
seaborn : 0.13.2
watermark : 2.5.0
scipy : 1.15.2
1 Introduction
One of the most reliable sources of waffles in North America, if not the entire world, is a Waffle House diner. Waffle House is nearly always open, even just after a hurricane. Most diners invest in disaster preparedness, including having their own electrical generators. As a consequence, the United States’ disaster relief agency (FEMA) informally uses Waffle House as an index of disaster severity. If the Waffle House is closed, that’s a serious event.
It is ironic then that steadfast Waffle House is associated with the nation’s highest divorce rates. States with many Waffle Houses per person, like Georgia and Alabama, also have some of the highest divorce rates in the United States. The lowest divorce rates are found where there are zero Waffle Houses. Could always-available waffles and hash brown potatoes put marriage at risk?
Probably not. This is an example of a misleading correlation. No one thinks there is any plausible mechanism by which Waffle House diners make divorce more likely. Instead, when we see a correlation of this kind, we immediately start asking about other variables that are really driving the relationship between waffles and divorce. In this case, Waffle House began in Georgia in the year 1955. Over time, the diners spread across the Southern United States, remaining largely within it. So Waffle House is associated with the South. Divorce is not a uniquely Southern institution, but the Southern United States has some of the highest divorce rates in the nation. So it’s probably just an accident of history that Waffle House and high divorce rates both occur in the South.
Such accidents are commonplace. It is not surprising that Waffle House is correlated with divorce, because correlation in general is not surprising. In large data sets, every pair of variables has a statistically discernible non-zero correlation. But since most correlations do not indicate causal relationships, we need tools for distinguishing mere association from evidence of causation. This is why so much effort is devoted to multiple regression, using more than one predictor variable to simultaneously model an outcome. Reasons given for multiple regression models include:
Statistical “control” for confounds. A confound is something that misleads us about a causal influence—there will be a more precise definition in the next chapter. The spurious waffles and divorce correlation is one type of confound, where southern-ness makes a variable with no real importance (Waffle House density) appear to be important. But confounds are diverse. They can hide important effects just as easily as they can produce false ones.
Multiple and complex causation. A phenomenon may arise from multiple simultaneous causes, and causes can cascade in complex ways. And since one cause can hide another, they must be measured simultaneously.
Interactions. The importance of one variable may depend upon another. For example, plants benefit from both light and water. But in the absence of either, the other is no benefit at all. Such interactions occur very often. Effective inference about one variable will often depend upon consideration of others.
In this chapter, we begin to deal with the first two of these, using multiple regression to deal with simple confounds and to take multiple measurements of association. You’ll see how to include any arbitrary number of main effects in your linear model of the Gaussian mean. These main effects are additive combinations of variables, the simplest type of multiple variable model. We’ll focus on two valuable things these models can help us with: (1) revealing spurious correlations like the Waffle House correlation with divorce and (2) revealing important correlations that may be masked by unrevealed correlations with other variables. Along the way, you’ll meet categorical variables, which require special handling compared to continuous variables.
However, multiple regression can be worse than useless, if we don’t know how to use it. Just adding variables to a model can do a lot of damage. In this chapter, we’ll begin to think formally about causal inference and introduce graphical causal models as a way to design and interpret regression models. The next chapter continues on this theme, describing some serious and common dangers of adding predictor variables, ending with a unifying framework for understanding the examples in both this chapter and the next.
Causal Inference
Despite its central importance, there is no unified approach to causal inference yet in the sciences. There are even people who argue that cause does not really exist; it’s just a psychological illusion. And in complex dynamical systems, everything seems to cause everything else. “Cause” loses intuitive value. About one thing, however, there is general agreement: Causal inference always depends upon unverifiable assumptions. Another way to say this is that it’s always possible to imagine some way in which your inference about cause is mistaken, no matter how careful the design or analysis. A lot can be accomplished, despite this barrier.
2 Spurious Association
Let’s leave waffles behind, at least for the moment. An example that is easier to understand is the correlation between divorce rate and marriage rate. The rate at which adults marry is a great predictor of divorce rate, as seen in the plots below. But does marriage cause divorce? In a trivial sense it obviously does: One cannot get a divorce without first getting married. But there’s no reason high marriage rate must cause more divorce. It’s easy to imagine high marriage rate indicating high cultural valuation of marriage and therefore being associated with low divorce rate.
Another predictor associated with divorce is the median age at marriage, also displayed in a plot below. Age at marriage is also a good predictor of divorce rate — higher age at marriage predicts less divorce. But there is no reason this has to be causal, either, unless age at marriage is very late and the spouses do not live long enough to get a divorce.
Let’s load these data and standardize the variables of interest:
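A minimal sketch of loading and standardizing, assuming the WaffleDivorce data are available as a semicolon-delimited CSV next to the notebook (the path, delimiter, and column names other than MedianAgeMarriage are assumptions):

# load the WaffleDivorce data and standardize the variables used below
d = pd.read_csv('data/WaffleDivorce.csv', sep=';')   # hypothetical path/format

def standardize(x):
    return (x - x.mean()) / x.std()

d['D'] = standardize(d.Divorce)            # divorce rate
d['M'] = standardize(d.Marriage)           # marriage rate
d['A'] = standardize(d.MedianAgeMarriage)  # median age at marriage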
\(D_i\) is the standardized (zero centered, standard deviation one) divorce rate for State \(i\), and \(A_i\) is State \(i\)’s standardized median age at marriage.
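Written out, with priors that match the Stan code for m5_1 below, the model is:

\[
\begin{aligned}
D_i &\sim \text{Normal}(\mu_i, \sigma) \\
\mu_i &= \alpha + \beta_A A_i \\
\alpha &\sim \text{Normal}(0, 0.2) \\
\beta_A &\sim \text{Normal}(0, 0.5) \\
\sigma &\sim \text{Exponential}(1)
\end{aligned}
\]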
What about those priors? Since the outcome and the predictor are both standardized, the intercept \(α\) should end up very close to zero. What does the prior slope \(β_A\) imply? If \(β_A =1\), that would imply that a change of one standard deviation in age at marriage is associated likewise with a change of one standard deviation in divorce. To know whether or not that is a strong relationship, you need to know how big a standard deviation of age at marriage is:
d.MedianAgeMarriage.std()
np.float64(1.2436303013880823)
So when \(β_A =1\), a change of 1.2 years in median age at marriage is associated with a full standard deviation change in the outcome variable. That seems like an insanely strong relationship. The prior above thinks that only 5% of plausible slopes are more extreme than 1. We’ll simulate from these priors in a moment, so you can see how they look in the outcome space.
To compute the approximate posterior, there are no new code tricks or techniques here. But I’ll add comments to help explain the mass of code to follow.
m5_1 = '''
data {
  int<lower=0> N;
  vector[N] A;
  vector[N] D;
  int<lower=0> N_tilde;
  vector[N_tilde] A_seq;
}
parameters {
  real a;
  real bA;
  real<lower=0> sigma;
}
transformed parameters {
  // Linear Model
  vector[N] mu;
  mu = a + bA * A;
  // Posterior Predictive Sampling
  vector[N_tilde] mu_tilde;
  mu_tilde = a + bA * A_seq;
}
model {
  D ~ normal(mu, sigma);
  a ~ normal(0, 0.2);
  bA ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
generated quantities {
  // Prior Predictive Simulation
  real a_sim = normal_rng(0, 0.2);
  real bA_sim = normal_rng(0, 0.5);
  vector[N_tilde] mu_sim = a_sim + bA_sim * A_seq;
  // Posterior Predictive Sampling - y_tilde or yhat
  vector[N_tilde] y_tilde;
  for (i in 1:N_tilde) {
    y_tilde[i] = normal_rng(mu_tilde[i], sigma);
  }
}
'''

A_seq = np.linspace(-3, 3.2, 100)
data = {
    'N': len(d),
    'A': d.A.tolist(),
    'D': d.D.tolist(),
    'N_tilde': len(A_seq),
    'A_seq': A_seq.tolist()
}
m5_1_model = utils.StanQuap(
    'stan_models/m5_1', m5_1, data=data,
    generated_var=['y_tilde', 'a_sim', 'bA_sim', 'mu_sim', 'mu_tilde', 'mu']
)
To simulate from the priors, I have set up the generated quantities block to generate a_sim and bA_sim from the defined priors and compute \(\mu_{sim}\). I’ll plot the lines over the range of 2 standard deviations for both the outcome and predictor. That’ll cover most of the possible range of both variables.
def plot_priors():
    mu_sim = m5_1_model.laplace_sample(draws=100).stan_variable('mu_sim')
    plt.plot(A_seq, mu_sim.T, 'k', alpha=0.2)
    plt.xlim(-2, 2)
    plt.xlabel('Median Age Marriage (std)')
    plt.ylabel('Divorce Rate (std)')

plt.clf()
plot_priors()
Plausible regression lines implied by the priors in m5_1. These are weakly informative priors in that they allow some implausibly strong relationships but generally bound the lines to possible ranges of the variables.
You may wish to try some vaguer, flatter priors and see how quickly the prior regression lines become ridiculous.
Now for the posterior predictions. The procedure is exactly like earlier: either use the link method to compute and summarize mu from the extracted parameter samples, or build mu_tilde and y_tilde (or in some conventions mu_hat and y_hat) directly into the Stan model and extract values without further computation. Then we summarize \(\mu\) with its mean and credible interval.
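For reference, the marriage-rate model m5_2 used below follows the same template as m5_1; a minimal sketch, without the prior and posterior simulation blocks (the stan_models/m5_2 path is an assumption):

# D ~ M: divorce rate regressed on marriage rate
m5_2 = '''
data {
  int<lower=0> N;
  vector[N] M;
  vector[N] D;
}
parameters {
  real a;
  real bM;
  real<lower=0> sigma;
}
transformed parameters {
  vector[N] mu = a + bM * M;
}
model {
  D ~ normal(mu, sigma);
  a ~ normal(0, 0.2);
  bM ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
'''

data = {'N': len(d), 'M': d.M.tolist(), 'D': d.D.tolist()}
m5_2_model = utils.StanQuap('stan_models/m5_2', m5_2, data=data, generated_var=['mu'])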
As you can see in the figure, this relationship isn’t as strong as the previous one.
Divorce rate is associated with both marriage rate and median age at marriage. Both predictor variables are standardized in this example. The average marriage rate across States is 20 per 1000 adults, and the average median age at marriage is 26 years.
But merely comparing parameter means between different bivariate regressions is no way to decide which predictor is better. Both of these predictors could provide independent value, or they could be redundant, or one could eliminate the value of the other.
To make sense of this, we’re going to have to think causally. And then, only after we’ve done some thinking, a bigger regression model that includes both age at marriage and marriage rate will help us.
2.1 Think Before you Regress
There are three observed variables in play: divorce rate (D), marriage rate (M), and the median age at marriage (A) in each State. The pattern we see in the previous two models and illustrated in the plots above is symptomatic of a situation in which only one of the predictor variables, A in this case, has a causal impact on the outcome, D, even though both predictor variables are strongly associated with the outcome.
To understand this better, it is helpful to introduce a particular type of causal graph known as a DAG, short for directed acyclic graph. Graph means it is nodes and connections. Directed means the connections have arrows that indicate directions of causal influence. And acyclic means that causes do not eventually flow back on themselves. A DAG is a way of describing qualitative causal relationships among variables. It isn’t as detailed as a full model description, but it contains information that a purely statistical model does not. Unlike a statistical model, a DAG will tell you the consequences of intervening to change a variable. But only if the DAG is correct. There is no inference without assumption.
The full framework for using DAGs to design and critique statistical models is complicated. So instead of smothering you in the whole framework right now, I’ll build it up one example at a time. By the end of the next chapter, you’ll have a set of simple rules that let you accomplish quite a lot of criticism. And then other applications will be introduced in later chapters.
Let’s start with the basics. Here is a possible DAG for our divorce rate example:
It may not look like much, but this type of diagram does a lot of work. It represents a heuristic causal model. Like other models, it is an analytical assumption. The symbols \(A\), \(M\), and \(D\) are our observed variables. The arrows show directions of influence. What this DAG says is:
\(A\) directly influences \(D\)
\(M\) directly influences \(D\)
\(A\) directly influences \(M\)
These statements can then have further implications. In this case, age of marriage influences divorce in two ways. First it has a direct effect, \(A → D\). Perhaps a direct effect would arise because younger people change faster than older people and are therefore more likely to grow incompatible with a partner. Second, it has an indirect effect by influencing the marriage rate, which then influences divorce, \(A → M → D\). If people get married earlier, then the marriage rate may rise, because there are more young people. Consider for example if an evil dictator forced everyone to marry at age 65. Since a smaller fraction of the population lives to 65 than to 25, forcing delayed marriage will also reduce the marriage rate. If marriage rate itself has any direct effect on divorce, maybe by making marriage more or less normative, then some of that direct effect could be the indirect effect of age at marriage.
To infer the strength of these different arrows, we need more than one statistical model. Model m5_1, the regression of \(D\) on \(A\), tells us only that the total influence of age at marriage is strongly negative with divorce rate. The “total” here means we have to account for every path from A to D. There are two such paths in this graph: \(A → D\), a direct path, and \(A → M → D\), an indirect path. In general, it is possible that a variable like \(A\) has no direct effect at all on an outcome like \(D\). It could still be associated with \(D\) entirely through the indirect path. That type of relationship is known as mediation, and we’ll have another example later.
As you’ll see however, the indirect path does almost no work in this case. How can we show that? We know from m5_2 that marriage rate is positively associated with divorce rate. But that isn’t enough to tell us that the path \(M → D\) is positive. It could be that the association between \(M\) and \(D\) arises entirely from \(A\)’s influence on both \(M\) and \(D\). Like this:
This DAG is also consistent with the posterior distributions of models m5_1 and m5_2. Why? Because both \(M\) and \(D\) “listen” to \(A\). They have information from \(A\). So when you inspect the association between \(D\) and \(M\), you pick up that common information that they both got from listening to \(A\). You’ll see a more formal way to deduce this later.
So which is it? Is there a direct effect of marriage rate, or rather is age at marriage just driving both, creating a spurious correlation between marriage rate and divorce rate? To find out, we need to consider carefully what each DAG implies.
What’s a cause?
Questions of causation can become bogged down in philosophical debates. These debates are worth having. But they don’t usually intersect with statistical concerns. Knowing a cause in statistics means being able to correctly predict the consequences of an intervention. There are contexts in which even this is complicated. For example, it isn’t possible to directly change someone’s body weight. Changing someone’s body weight would mean intervening on another variable, like diet, and that variable would have other causal effects in addition. But being underweight can still be a legitimate cause of disease, even when we can’t intervene on it directly.
2.2 Testable Implications
How do we use data to compare multiple, plausible causal models? The first thing to consider is the testable implications of each model. Consider the two DAGs we have so far considered:
plt.clf()
pgm1.render()
plt.gca().invert_yaxis()
plt.clf()
pgm2.render()
plt.gca().invert_yaxis()
Any DAG may imply that some variables are independent of others under certain conditions. These are the model’s testable implications, its conditional independencies. Conditional independencies come in two forms. First, they are statements of which variables should be associated with one another (or not) in the data. Second, they are statements of which variables become dis-associated when we condition on some other set of variables.
What does “conditioning” mean? Informally, conditioning on a variable \(Z\) means learning its value and then asking if \(X\) adds any additional information about \(Y\). If learning \(X\) doesn’t give you any more information about \(Y\), then we might say that \(Y\) is independent of \(X\) conditional on \(Z\). This conditioning statement is sometimes written as: \(Y \perp\!\!\!\perp X \mid Z\). This is very weird notation and any feelings of annoyance on your part are justified. We’ll work with this concept a lot, so don’t worry if it doesn’t entirely make sense right now. You’ll see examples very soon.
Let’s consider conditional independence in the context of the divorce example. What are the conditional independencies of the DAGs at the top? How do we derive these conditional independencies? Finding conditional independencies is not hard, but also not at all obvious. With a little practice, it becomes very easy. The more general rules can wait until later. For now, let’s consider each DAG in turn and inspect the possibilities.
For the DAG on the left above, the one with three arrows, first note that every pair of variables is correlated. This is because there is a causal arrow between every pair. These arrows create correlations. So before we condition on anything, everything is associated with everything else. This is already a testable implication. We could write it:
\[
D \not\!\perp\!\!\!\perp A \qquad D \not\!\perp\!\!\!\perp M \qquad A \not\!\perp\!\!\!\perp M
\]
That \(\not\!\perp\!\!\!\perp\) thing means “not independent of”. If we now look in the data and find that any pair of variables are not associated, then something is wrong with the DAG (assuming the data are correct). In these data, all three pairs are in fact strongly associated. Check for yourself. You can use the data frame’s corr method to measure simple correlations, as shown below. Correlations are sometimes terrible measures of association — many different patterns of association with different implications can produce the same correlation. But they do honest work in this case.
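For example, using the standardized columns created earlier:

# simple pairwise correlations among the standardized variables
d[['D', 'M', 'A']].corr().round(2)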
Are there any other testable implications for the first DAG above? No. It will be easier to see why, if we slide over to consider the second DAG, the one in which \(M\) has no influence on \(D\). In this DAG, it is still true that all three variables are associated with one another. \(A\) is associated with \(D\) and \(M\) because it influences them both. And \(D\) and \(M\) are associated with one another, because \(M\) influences them both. They share a cause, and this leads them to be correlated with one another through that cause. But suppose we condition on \(A\). All of the information in M that is relevant to predicting \(D\) is in \(A\). So once we’ve conditioned on \(A\), \(M\) tells us nothing more about \(D\). So in the second DAG, a testable implication is that \(D\) is independent of \(M\), conditional on \(A\). In other words, \(D \perp\!\!\!\perp M \mid A\). The same thing does not happen with the first DAG. Conditioning on \(A\) does not make \(D\) independent of \(M\), because \(M\) really influences \(D\) all by itself in this model.
Later, we’ll see the general rules for deducing these implications. For now, the daft & CausalGraphicalModel packages have the rules built in and can find the implications for you. Here’s the code to define the second DAG and display the implied conditional independencies.
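The dag5_1 and dag5_2 objects used below can be constructed with the CausalGraphicalModel class from the causalgraphicalmodels package; a sketch (the daft PGM objects pgm1 and pgm2 used for rendering above are built separately and not shown here):

from causalgraphicalmodels import CausalGraphicalModel

# DAG 1: A -> D, A -> M, M -> D
dag5_1 = CausalGraphicalModel(
    nodes=['A', 'D', 'M'],
    edges=[('A', 'D'), ('A', 'M'), ('M', 'D')]
)

# DAG 2: A -> D, A -> M (no arrow from M to D)
dag5_2 = CausalGraphicalModel(
    nodes=['A', 'D', 'M'],
    edges=[('A', 'D'), ('A', 'M')]
)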
print('DAG 1:', dag5_1.get_distribution())
print('DAG 2:', dag5_2.get_distribution())

# get_all_independence_relationships() method
# Returns a list of all pairwise conditional independence relationships
# implied by the graph structure.
dag5_2.get_all_independence_relationships()
# [('M', 'D', {'A'})] means D _||_ M | A
DAG 1: P(A)P(M|A)P(D|A,M)
DAG 2: P(A)P(D|A)P(M|A)
[('M', 'D', {'A'})]
The first DAG has no conditional independencies.
dag5_1.get_all_independence_relationships()
[]
There are no conditional independencies, so there is no output to display.
Let’s try to summarize. The testable implications of the first DAG are that all pairs of variables should be associated, whatever we condition on. The testable implications of the second DAG are that all pairs of variables should be associated, before conditioning on anything, but that \(D\) and \(M\) should be independent after conditioning on \(A\). So the only implication that differs between these DAGs is the last one: \(D \perp\!\!\!\perp M \mid A\).
To test this implication, we need a statistical model that conditions on \(A\), so we can see whether that renders \(D\) independent of \(M\). And that is what multiple regression helps with. It can address a useful descriptive question:
Is there any additional value in knowing a variable, once I already know all of the other predictor variables?
So for example once you fit a multiple regression to predict divorce using both marriage rate and age at marriage, the model addresses the questions:
After I already know marriage rate, what additional value is there in also knowing age at marriage?
After I already know age at marriage, what additional value is there in also knowing marriage rate?
The parameter estimates corresponding to each predictor are the (often opaque) answers to these questions. The questions above are descriptive, and the answers are also descriptive. It is only the derivation of the testable implications above that gives these descriptive results a causal meaning. But that meaning is still dependent upon believing the DAG.
“Control” is out of control
Very often, the question just above is spoken of as “statistical control,” as in controlling for the effect of one variable while estimating the effect of another. But this is sloppy language, as it implies too much. Statistical control is quite different from experimental control, as we’ll explore more later. The point here isn’t to police language. Instead, the point is to observe the distinction between small world and large world interpretations. Since most people who use statistics are not statisticians, sloppy language like “control” can promote a sloppy culture of interpretation. Such cultures tend to overestimate the power of statistical methods, so resisting them can be difficult. Disciplining your own language may be enough. Disciplining another’s language is hard to do, without seeming like a fastidious scold, as this very box must seem.
2.3 Multiple Regression Notation
Multiple regression formulas look a lot like the polynomial models: they add more parameters and variables to the definition of \(μ_i\). The strategy is straightforward:
Nominate the predictor variables you want in the linear model of the mean.
For each predictor, make a parameter that will measure its conditional association with the outcome.
Multiply the parameter by the variable and add that term to the linear model.
Here is the model that predicts divorce rate, using both marriage rate and age at marriage.
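Written out, with priors matching the Stan code for m5_3 below:

\[
\begin{aligned}
D_i &\sim \text{Normal}(\mu_i, \sigma) \\
\mu_i &= \alpha + \beta_M M_i + \beta_A A_i \\
\alpha &\sim \text{Normal}(0, 0.2) \\
\beta_M &\sim \text{Normal}(0, 0.5) \\
\beta_A &\sim \text{Normal}(0, 0.5) \\
\sigma &\sim \text{Exponential}(1)
\end{aligned}
\]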
You can use whatever symbols you like for the parameters and variables, but here I’ve chosen \(M\) for marriage rate and \(A\) for age at marriage, reusing these symbols as subscripts for the corresponding parameters. But feel free to use whichever symbols reduce the load on your own memory.
So what does it mean to assume \(\mu_i = \alpha + \beta_M M_i + \beta_A A_i\)? Mechanically, it means that the expected outcome for any State with marriage rate \(M_i\) and median age at marriage \(A_i\) is the sum of three independent terms. If you are like most people, this is still pretty mysterious. The mechanical meaning of the equation doesn’t map onto a unique causal meaning. Let’s take care of the mechanical bits first, before returning to interpretation.
Compact notation and the design matrix
Often, linear models are written using a compact form like:
\[
\mu_i = \alpha + \sum_{j=1}^{n} \beta_j x_{ji}
\]
where \(j\) is an index over predictor variables and \(n\) is the number of predictor variables. This may be read as: the mean is modeled as the sum of an intercept and an additive combination of the products of parameters and predictors. Even more compactly, using matrix notation:
\[
\mathbf{m} = \mathbf{Xb}
\]
where \(\mathbf{m}\) is a vector of predicted means, one for each row in the data, \(\mathbf{b}\) is a (column) vector of parameters, one for each predictor variable (including the intercept), and \(\mathbf{X}\) is a matrix. This matrix is called a design matrix. It has as many rows as the data, and as many columns as there are predictors plus one. So \(\mathbf{X}\) is basically a data frame (or 2D array), but with an extra first column. The extra column is filled with 1s. These 1s are multiplied by the first parameter, which is the intercept, and so return the unmodified intercept. When \(\mathbf{X}\) is matrix-multiplied by \(\mathbf{b}\), you get the predicted means. In Python notation, this operation is X @ b, where X has shape [rows, cols] and b has shape [cols] or [cols, 1] (reshape b if it is stored as a row vector).
m5_dmat = '''
data {
  int<lower=1> N;        // number of observations
  int<lower=1> K;        // number of regressors (including constant)
  vector[N] D;           // outcome
  matrix[N, K] X;        // regressors
}
parameters {
  real<lower=0> sigma;   // scale
  vector[K] b;           // coefficients (including constant)
}
transformed parameters {
  vector[N] mu = X * b;  // location
}
model {
  D ~ normal(mu, sigma); // probability model
  sigma ~ exponential(1);// prior for scale
  b[1] ~ normal(0, 0.2); // prior for intercept
  for (i in 2:K) {       // priors for coefficients
    b[i] ~ normal(0, 0.5);
  }
}
generated quantities {
  vector[N] y_tilde;     // predicted outcome
  for (i in 1:N) y_tilde[i] = normal_rng(mu[i], sigma);
}
'''

m5_dmat_model = utils.Stan('stan_models/m5_dmat', m5_dmat)
We will also use the design matrix approach (as we have done earlier). It’s good to recognize it, and sometimes it can save you a lot of work. For example, for linear regressions, there is a nice matrix formula for the maximum likelihood (or least squares) estimates. Most statistical software exploits that formula.
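As a small aside, here is a sketch of that least-squares formula with NumPy, building a design matrix with an intercept column (illustration only; the Bayesian models are what we actually use in this chapter):

# ordinary least squares via the normal equations: b_hat = (X'X)^-1 X'y
X = np.column_stack([np.ones(len(d)), d.A, d.M])         # design matrix with a column of 1s
b_hat = np.linalg.solve(X.T @ X, X.T @ d.D.to_numpy())   # [intercept, coef on A, coef on M]
b_hat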
And here is the code to approximate the posterior distribution:
m5_3 = '''
data {
  int<lower=0> N;
  vector[N] M;
  vector[N] D;
  vector[N] A;
}
parameters {
  real a;
  real bA;
  real bM;
  real<lower=0> sigma;
}
transformed parameters {
  vector[N] mu;
  mu = a + bA * A + bM * M;
}
model {
  D ~ normal(mu, sigma);
  a ~ normal(0, 0.2);
  bA ~ normal(0, 0.5);
  bM ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
generated quantities {
  // Posterior Predictive Check - y_rep (replications)
  vector[N] y_rep;
  for (i in 1:N) {
    y_rep[i] = normal_rng(mu[i], sigma);
  }
}
'''

M_seq = np.linspace(-2, 3, 100)
data = {
    'N': len(d),
    'A': d.A.tolist(),
    'M': d.M.tolist(),
    'D': d.D.tolist()
}
m5_3_model = utils.StanQuap(
    'stan_models/m5_3', m5_3, data=data,
    generated_var=['y_rep', 'mu']
)
m5_3_model.precis().round(2)
            Mean  StDev   5.5%  94.5%
Parameter
a           0.00   0.10  -0.16   0.16
bA         -0.61   0.15  -0.85  -0.37
bM         -0.07   0.15  -0.31   0.18
sigma       0.79   0.08   0.66   0.91
The posterior mean for marriage rate, bM, is now close to zero, with plenty of probability on both sides of zero. The posterior mean for age at marriage, bA, is essentially unchanged. It will help to visualize the posterior distributions for all three models, focusing just on the slope parameters \(β_A\) and \(β_M\):
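A sketch of one way to draw that comparison, assuming precis() returns a DataFrame indexed by parameter name (as the summaries in this notebook suggest) and using the m5_2_model sketched earlier:

def plot_coef_comparison():
    # slope parameters from the three models, with 89% intervals
    entries = [('m5_3: bM', m5_3_model, 'bM'),
               ('m5_2: bM', m5_2_model, 'bM'),
               ('m5_3: bA', m5_3_model, 'bA'),
               ('m5_1: bA', m5_1_model, 'bA')]
    y = np.arange(len(entries))
    for yi, (label, model, par) in zip(y, entries):
        p = model.precis()
        plt.plot(p.loc[par, 'Mean'], yi, 'ko')
        plt.hlines(yi, p.loc[par, '5.5%'], p.loc[par, '94.5%'], colors='k')
    plt.axvline(0, ls='--', color='gray')
    plt.yticks(y, [e[0] for e in entries])
    plt.xlabel('posterior mean and 89% interval')

plt.clf()
plot_coef_comparison()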
The posterior means are shown by the points and the 89% compatibility intervals by the solid horizontal lines. Notice how bA doesn’t move, only grows a bit more uncertain, while bM is only associated with divorce when age at marriage is missing from the model. You can interpret these distributions as saying: Once we know median age at marriage for a State, there is little or no additional predictive power in also knowing the rate of marriage in that State.
In that weird notation, \(D \perp\!\!\!\perp M \mid A\). This tests the implication of the second DAG from earlier. Since the first DAG did not imply this result, it is out.
Note that this does not mean that there is no value in knowing marriage rate. Consistent with the earlier DAG, if you didn’t have access to age-at-marriage data, then you’d definitely find value in knowing the marriage rate. \(M\) is predictive but not causal. Assuming there are no other causal variables missing from the model (more on that later), this implies there is no important direct causal path from marriage rate to divorce rate. The association between marriage rate and divorce rate is spurious, caused by the influence of age of marriage on both marriage rate and divorce rate. I’ll leave it to the reader to investigate the relationship between age at marriage, \(A\), and marriage rate, \(M\), to complete the picture.
But how did model m5_3 achieve the inference that marriage rate adds no additional information, once we know age at marriage? Let’s draw some pictures.
Simulating the divorce example
The divorce data are real data. But it is useful to simulate the kind of causal relationships shown in the previous DAG: \(M ← A → D\). Every DAG implies a simulation, and such simulations can help us design models to correctly infer relationships among variables. In this case, you just need to simulate each of the three variables:
N = 50
age = stats.norm().rvs(N)      # sim A
mar = stats.norm(-age).rvs(N)  # sim A -> M
div = stats.norm(age).rvs(N)   # sim A -> D
Now if you use these variables in models m5_1, m5_2, and m5_3, you’ll see the same pattern of posterior inferences. It is also possible to simulate that both \(A\) and \(M\) influence \(D\): div = stats.norm(age + mar).rvs(N). In that case, a naive regression of \(D\) on \(A\) will overestimate the influence of \(A\), just like a naive regression of \(D\) on \(M\) will overestimate the importance of \(M\). The multiple regression will help sort things out for you in this situation as well. But interpreting the parameter estimates will always depend upon what you believe about the causal model, because typically several (or very many) causal models are consistent with any one set of parameter estimates. We’ll discuss this later as Markov equivalence.
2.5 Plotting Multivariate Posteriors
Let’s pause for a moment, before moving on. There are a lot of moving parts here: three variables, some strange DAGs, and three models. If you feel at all confused, it is only because you are paying attention.
It will help to visualize the model’s inferences. Visualizing the posterior distribution in simple bivariate regressions, like those in the earlier, is easy. There’s only one predictor variable, so a single scatterplot can convey a lot of information. And so earlier we used scatters of the data. Then we overlaid regression lines and intervals to both (1) visualize the size of the association between the predictor and outcome and (2) to get a crude sense of the ability of the model to predict the individual observations.
With multivariate regression, you’ll need more plots. There is a huge literature detailing a variety of plotting techniques that all attempt to help one understand multiple linear regression. None of these techniques is suitable for all jobs, and most do not generalize beyond linear regression. So the approach I take here is to instead help you compute whatever you need from the model. I offer three examples of interpretive plots:
Predictor residual plots. These plots show the outcome against residual predictor values. They are useful for understanding the statistical model, but not much else.
Posterior prediction plots. These show model-based predictions against raw data, or otherwise display the error in prediction. They are tools for checking fit and assessing predictions. They are not causal tools.
Counterfactual plots. These show the implied predictions for imaginary experiments. These plots allow you to explore the causal implications of manipulating one or more variables.
Each of these plot types has its advantages and deficiencies, depending upon the context and the question of interest. In the rest of this section, I show you how to manufacture each of these in the context of the divorce data.
Predictor Residual Plots
A predictor residual is the average prediction error when we use all of the other predictor variables to model a predictor of interest. That’s a complicated concept, so we’ll go straight to the example, where it will make sense. The benefit of computing these things is that, once plotted against the outcome, we have a bivariate regression that has already conditioned on all of the other predictor variables. It leaves the variation that is not expected by the model of the mean, \(μ\), as a function of the other predictors.
In our model of divorce rate, we have two predictors: (1) marriage rate \(M\) and (2) median age at marriage \(A\). To compute predictor residuals for either, we just use the other predictor to model it. So for marriage rate, this is the model we need:
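In model form, matching the Stan code for m5_4 below:

\[
\begin{aligned}
M_i &\sim \text{Normal}(\mu_i, \sigma) \\
\mu_i &= \alpha + \beta_{AM} A_i \\
\alpha &\sim \text{Normal}(0, 0.2) \\
\beta_{AM} &\sim \text{Normal}(0, 0.5) \\
\sigma &\sim \text{Exponential}(1)
\end{aligned}
\]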
As before, \(M\) is marriage rate and \(A\) is median age at marriage. Note that since we standardized both variables, we already expect the mean \(α\) to be around zero, as before. So I’m reusing the same priors as earlier. This code will approximate the posterior:
# A -> M
m5_4 = '''
data {
  int<lower=0> N;
  vector[N] M;
  vector[N] A;
}
parameters {
  real a;
  real bAM;
  real<lower=0> sigma;
}
transformed parameters {
  vector[N] mu;
  mu = a + bAM * A;
}
model {
  M ~ normal(mu, sigma);
  a ~ normal(0, 0.2);
  bAM ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
'''

# M -> A
m5_4a = '''
data {
  int<lower=0> N;
  vector[N] M;
  vector[N] A;
}
parameters {
  real a;
  real bMA;
  real<lower=0> sigma;
}
transformed parameters {
  vector[N] mu;
  mu = a + bMA * M;
}
model {
  A ~ normal(mu, sigma);
  a ~ normal(0, 0.2);
  bMA ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
'''

data = {'N': len(d), 'A': d.A.tolist(), 'M': d.M.tolist()}
m5_4_model = utils.StanQuap('stan_models/m5_4', m5_4, data=data, generated_var=['y_rep', 'mu'])
m5_4a_model = utils.StanQuap('stan_models/m5_4a', m5_4a, data=data, generated_var=['y_rep', 'mu'])
And then we compute the residuals by subtracting the predicted marriage rate, based upon the model above, from the observed marriage rate in each State:
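A sketch of that computation, using laplace_sample as in the posterior prediction code later in this section:

# expected marriage rate for each State, given its median age at marriage
mu = m5_4_model.laplace_sample(draws=1000).stan_variable('mu')
mu_mean = mu.mean(axis=0)
mu_resid = d.M - mu_mean   # predictor residuals: observed minus predicted marriage rate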
Understanding multiple regression through residuals. The top row shows each predictor regressed on the other predictor. The lengths of the line segments connecting the model’s expected value of the outcome, the regression line, and the actual value are the residuals. In the bottom row, divorce rate is regressed on the residuals from the top row. Bottom left: Residual variation in marriage rate shows little association with divorce rate. Bottom right: Divorce rate on age at marriage residuals, showing remaining variation, and this variation is associated with divorce rate.
When a residual is positive, that means that the observed rate was in excess of what the model expects, given the median age at marriage in that State. When a residual is negative, that means the observed rate was below what the model expects. In simpler terms, States with positive residuals have high marriage rates for their median age of marriage, while States with negative residuals have low rates for their median age of marriage. It’ll help to plot the relationship between these two variables, and show the residuals as well. In the plots above, we show m5_4 along with line segments for each residual. Notice that the residuals are variation in marriage rate that is left over, after taking out the purely linear relationship between the two variables.
Now to use these residuals, let’s put them on a horizontal axis and plot them against the actual outcome of interest, divorce rate. In the plot above, we also plot these residuals against divorce rate, overlaying the linear regression of the two variables. You can think of this plot as displaying the linear relationship between divorce and marriage rates, having conditioned already on median age of marriage. The vertical dashed line indicates marriage rate that exactly matches the expectation from median age at marriage. So States to the right of the line have higher marriage rates than expected. States to the left of the line have lower rates. Average divorce rate on both sides of the line is about the same, and so the regression line demonstrates little relationship between divorce and marriage rates.
The same procedure works for the other predictor. The plot above shows the regression of \(A\) on \(M\) and the residuals. In the plot below, these residuals are used to predict divorce rate. States to the right of the vertical dashed line have older-than-expected median age at marriage, while those to the left have younger-than-expected median age at marriage. Now we find that the average divorce rate on the right is lower than the rate on the left, as indicated by the regression line. States in which people marry older than expected for a given rate of marriage tend to have less divorce.
So what’s the point of all of this? There’s conceptual value in seeing the model-based predictions displayed against the outcome, after subtracting out the influence of other predictors. The plots in this section do this. But this procedure also brings home the message that regression models measure the remaining association of each predictor with the outcome, after already knowing the other predictors. In computing the predictor residual plots, we had to perform those calculations ourselves. In the unified multivariate model, it all happens automatically. Nevertheless, it is useful to keep this fact in mind, because regressions can behave in surprising ways as a result. We’ll have an example soon.
Linear regression models do all of this simultaneous measurement with a very specific additive model of how the variables relate to one another. But predictor variables can be related to one another in non-additive ways. The basic logic of statistical conditioning does not change in those cases, but the details definitely do, and these residual plots cease to be useful. Luckily there are other ways to understand a model. That’s where we turn next.
Residuals are parameters, not data
There is a tradition, especially in parts of biology, of using residuals from one model as data in another model. For example, a biologist might regress brain size on body size and then use the brain size residuals as data in another model. This procedure is always a mistake. Residuals are not known. They are parameters, variables with unobserved values. Treating them as known values throws away uncertainty. The right way to adjust for body size is to include it in the same model, preferably a model designed in light of an explicit causal model.
Posterior Prediction Plots
It’s important to check the model’s implied predictions against the observed data. This is what we did earlier, when we simulated globe tosses, averaging over the posterior, and comparing the simulated results to the observed. These kinds of checks are useful in many ways. For now, we’ll focus on two uses.
Did the model correctly approximate the posterior distribution? Errors can be more easily diagnosed by comparing implied predictions to the raw data. Some caution is required, because not all models try to exactly match the sample. But even then, you’ll know what to expect from a successful approximation. You’ll see some examples later.
How does the model fail? Models are useful fictions. So they always fail in some way. Sometimes, a model fits correctly but is still so poor for our purposes that it must be discarded. More often, a model predicts well in some respects, but not in others. By inspecting the individual cases where the model makes poor predictions, you might get an idea of how to improve it. The difficulty is that this process is essentially creative and relies upon the analyst’s domain expertise. No robot can (yet) do it for you. It also risks chasing noise, a topic we’ll focus on in later chapters.
How could we produce a simple posterior predictive check in the divorce example? Let’s begin by simulating predictions, averaging over the posterior.
# call laplace_sample to directly obtain mu (linear model data)
# Not specifying new data so it uses original data
mu = m5_3_model.laplace_sample(draws=1000).stan_variable('mu')

# summarize samples across cases
mu_mean, mu_plo, mu_phi = utils.precis(mu)

# simulate observations. Again no new data, so uses original data
D_sim = m5_3_model.sim(n=1000)['y_rep']
D_mean, D_plo, D_phi = utils.precis(D_sim)
This code is similar to what you’ve seen before, but now using the original observed data. For multivariate models, there are many different ways to display these simulations. The simplest is to just plot predictions against observed. This code will do that, and then add a line to show perfect prediction and line segments for the confidence interval of each prediction:
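A sketch of such a plot, using the summaries computed above (the exact styling of the original figure is not reproduced):

def plot_ppc():
    # observed divorce rate vs. posterior mean prediction, with 89% intervals
    plt.scatter(d.D, mu_mean, facecolors='none', edgecolors='C0')
    plt.vlines(d.D, mu_plo, mu_phi, colors='k', alpha=0.5)  # compatibility intervals
    lims = [d.D.min(), d.D.max()]
    plt.plot(lims, lims, 'k--', alpha=0.5)                  # perfect prediction line
    plt.xlabel('Observed divorce')
    plt.ylabel('Predicted divorce')

plt.clf()
plot_ppc()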
Posterior predictive plot for the multivariate divorce model, m5_3. The horizontal axis is the observed divorce rate in each State. The vertical axis is the model’s posterior predicted divorce rate, given each State’s median age at marriage and marriage rate. The black line segments are 89% compatibility intervals. The diagonal line shows where posterior predictions exactly match the sample.
It’s easy to see from this arrangement of the simulations that the model under-predicts for States with very high divorce rates while it over-predicts for States with very low divorce rates. That’s normal. This is what regression does—it is skeptical of extreme values, so it expects regression towards the mean. But beyond this general regression to the mean, some States are very frustrating to the model, lying very far from the diagonal. I’ve labeled some points like this, including Idaho (ID) and Utah (UT), both of which have much lower divorce rates than the model expects them to have.
What is unusual about Idaho and Utah? Both of these States have large proportions of members of the Church of Jesus Christ of Latter-day Saints. Members of this church have low rates of divorce, wherever they live. This suggests that having a finer view on the demographic composition of each State, beyond just median age at marriage, would help.
Stats, huh, yeah what is it good for?
Often people want statistical modeling to do things that statistical modeling cannot do. For example, we’d like to know whether an effect is “real” or rather spurious. Unfortunately, modeling merely quantifies uncertainty in the precise way that the model understands the problem. Usually answers to large world questions about truth and causation depend upon information not included in the model. For example, any observed correlation between an outcome and predictor could be eliminated or reversed once another predictor is added to the model. But if we cannot think of the right variable, we might never notice. Therefore all statistical models are vulnerable to and demand critique, regardless of the precision of their estimates and apparent accuracy of their predictions. Rounds of model criticism and revision embody the real tests of scientific hypotheses. A true hypothesis will pass and fail many statistical “tests” on its way to acceptance.
Simulating Spurious Association
One way that spurious associations between a predictor and outcome can arise is when a truly causal predictor, call it \(x_{real}\), influences both the outcome, \(y\), and a spurious predictor, \(x_{spur}\). This can be confusing, however, so it may help to simulate this scenario and see both how the spurious data arise and prove to yourself that multiple regression can reliably indicate the right predictor, \(x_{real}\). So here’s a very basic simulation:
N = 100                                          # Number of Cases
x_real = np.random.normal(size=100)              # x_real as Gaussian with mean 0 and stddev 1
x_spur = np.random.normal(loc=x_real, size=100)  # x_spur as Gaussian with mean = x_real
y = np.random.normal(loc=x_real, size=100)       # y as Gaussian with mean = x_real
df = pd.DataFrame({'y': y, 'x_real': x_real, 'x_spur': x_spur})  # bind all together in data frame
Now the data frame df has 100 simulated cases. Because x_real influences both y and x_spur, you can think of x_spur as another outcome of x_real, but one which we mistake as a potential predictor of y. As a result, both \(x_{real}\) and \(x_{spur}\) are correlated with y. You can see this in the scatterplots. But when you include both \(x\) variables in a linear regression predicting \(y\), the posterior mean for the association between \(y\) and \(x_{spur}\) will be close to zero.
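A quick least-squares check (a sketch; refitting the Stan models above with y, x_real, and x_spur as data would show the same pattern):

# regress y on both predictors; the coefficient on x_spur should be near zero
X = np.column_stack([np.ones(N), df.x_real, df.x_spur])
np.linalg.lstsq(X, df.y.to_numpy(), rcond=None)[0]   # [intercept, b_real, b_spur]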
Counterfactual Plots
A second sort of inferential plot displays the causal implications of the model. I call these plots counterfactual, because they can be produced for any values of the predictor variables you like, even unobserved combinations like very high median age of marriage and very high marriage rate. There are no States with this combination, but in a counterfactual plot, you can ask the model for a prediction for such a State, asking questions like “What would Utah’s divorce rate be, if its median age at marriage were higher?” Used with clarity of purpose, counterfactual plots help you understand the model, as well as generate predictions for imaginary interventions and compute how much some observed outcome could be attributed to some cause.
Note that the term “counterfactual” is highly overloaded in statistics and philosophy. It hardly ever means the same thing when used by different authors. Here, I use it to indicate some computation that makes use of the structural causal model, going beyond the posterior distribution. But it could refer to questions about both the past and the future.
The simplest use of a counterfactual plot is to see how the outcome would change as you change one predictor at a time. If some predictor X took on a new value for one or more cases in our data, how would the outcome Y have changed? Changing just one predictor X might also change other predictors, depending upon the causal model. Suppose for example that you pay young couples to postpone marriage until they are 35 years old. Surely this will also decrease the number of couples who ever get married—some people will die before turning 35, among other reasons—decreasing the overall marriage rate. An extraordinary and evil degree of control over people would be necessary to really hold marriage rate constant while forcing everyone to marry at a later age.
So let’s see how to generate plots of model predictions that take the causal structure into account. The basic recipe is:
Pick a variable to manipulate, the intervention variable.
Define the range of values to set the intervention variable to.
For each value of the intervention variable, and for each sample in the posterior, use the causal model to simulate the values of the other variables, including the outcome.
In the end, you end up with a posterior distribution of counterfactual outcomes that you can plot and summarize in various ways, depending upon your goal.
Let’s see how to do this for the divorce model. Again we take this DAG as given:
pgm1.render()
plt.gca().invert_yaxis();
To simulate from this, we need more than the DAG. We also need a set of functions that tell us how each variable is generated. For simplicity, we’ll use Gaussian distributions for each variable, just like in model m5_3. But model m5_3 ignored the assumption that \(A\) influences \(M\). We didn’t need that to estimate \(A → D\). But we do need it to predict the consequences of manipulating \(A\), because some of the effect of \(A\) acts through \(M\).
To estimate the influence of \(A\) on \(M\), all we need is to regress \(M\) on \(A\). There are no other variables in the DAG creating an association between \(A\) and \(M\). We can just add this regression to the model, running two regressions at the same time:
m5_3_A = '''
data {
  int<lower=0> N;          // Number of observations
  vector[N] M;             // Marriage rate
  vector[N] D;             // Divorce rate
  vector[N] A;             // Median age at marriage
  int<lower=0> N_tilde;    // Number of counterfactual simulations
  vector[N_tilde] A_seq;   // Counterfactual A values for A -> M -> D path
  vector[N_tilde] M_seq;   // Counterfactual M values for direct M -> D path
}
parameters {
  real a;
  real bA;
  real bM;
  real<lower=0> sigma;
  real aM;
  real bAM;
  real<lower=0> sigma_M;
}
transformed parameters {
  vector[N] mu;
  mu = a + bA * A + bM * M;  // A -> D <- M
  vector[N] mu_M;
  mu_M = aM + bAM * A;       // A -> M
}
model {
  // Priors
  a ~ normal(0, 0.2);
  bA ~ normal(0, 0.5);
  bM ~ normal(0, 0.5);
  sigma ~ exponential(1);
  aM ~ normal(0, 0.2);
  bAM ~ normal(0, 0.5);
  sigma_M ~ exponential(1);
  // Likelihood
  D ~ normal(mu, sigma);
  M ~ normal(mu_M, sigma_M);
}
generated quantities {
  vector[N_tilde] M_tilde;
  vector[N_tilde] D_tilde;
  vector[N_tilde] D_tilde_M;
  // Simulating M first (A -> M)
  for (i in 1:N_tilde) {
    M_tilde[i] = normal_rng(aM + bAM * A_seq[i], sigma_M);
  }
  // Simulating D given new M (A -> D <- M)
  for (i in 1:N_tilde) {
    D_tilde[i] = normal_rng(a + bA * A_seq[i] + bM * M_tilde[i], sigma);
  }
  // Simulating D for directly controlled M
  for (i in 1:N_tilde) {
    D_tilde_M[i] = normal_rng(a + bA * A_seq[i] + bM * M_seq[i], sigma);
  }
}
'''

seq = np.linspace(-2, 2, 30)
data = {
    'N': len(d),
    'A': d.A.tolist(),
    'M': d.M.tolist(),
    'D': d.D.tolist(),
    'N_tilde': len(seq),
    'A_seq': seq.tolist(),
    'M_seq': seq.tolist()
}
m5_3_A_model = utils.StanQuap(
    'stan_models/m5_3_A', m5_3_A, data=data, algorithm='LBFGS',
    generated_var=['M_tilde', 'D_tilde', 'D_tilde_M', 'mu', 'mu_M']
)
m5_3_A_model.precis().round(2)
            Mean  StDev   5.5%  94.5%
Parameter
a           0.00   0.10  -0.16   0.16
bA         -0.61   0.15  -0.85  -0.37
bM         -0.07   0.15  -0.31   0.18
sigma       0.79   0.08   0.66   0.91
aM          0.00   0.09  -0.14   0.14
bAM        -0.69   0.10  -0.85  -0.54
sigma_M     0.68   0.07   0.57   0.79
Look at the summary. You’ll see that \(M\) and \(A\) are strongly negatively associated. If we interpret this causally, it indicates that increasing \(A\) reduces \(M\).
The goal is to simulate what would happen if we manipulate \(A\). During model set-up, we defined a range of values for \(A\) in A_seq: a list of 30 imaginary interventions, ranging from 2 standard deviations below the mean to 2 above it. Now we can use sim, which you met in the previous chapter, to simulate observations from model m5_3_A. But this time we’ll tell it to simulate both \(M\) and \(D\), in that order. Why in that order? Because we have to simulate the influence of \(A\) on \(M\) before we simulate the joint influence of \(A\) and \(M\) on \(D\). Note that this logic was also coded right into our Stan model’s generated quantities block, so the model setup already tells it both which observables to simulate and in which order.
# simulate M_tilde and then D_tilde, using A_seq
s = m5_3_A_model.sim(n=1000, select=['M_tilde', 'D_tilde'])
That’s all there is to it. But do at least glance at the ‘Simulating counterfactuals’ box at the end of this section, where we walk through the individual steps, so we can perform this kind of counterfactual simulation for any model fit with any software. Now to plot the predictions:
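A sketch of the plotting code, mirroring the plot_counterfactual helper used later in this section (utils.precis is assumed to return mean, lower, and upper arrays, as elsewhere in this notebook):

def plot_counterfactuals():
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    # Left: total counterfactual effect of A on D (both paths)
    D_mean, D_plo, D_phi = utils.precis(s['D_tilde'])
    ax1.plot(seq, D_mean)
    ax1.fill_between(seq, D_plo, D_phi, alpha=0.2)
    ax1.set_xlabel('Manipulated A'), ax1.set_ylabel('Counterfactual D')
    ax1.set_title('Total Counterfactual Effect of A on D')
    # Right: counterfactual effect of A on M
    M_mean, M_plo, M_phi = utils.precis(s['M_tilde'])
    ax2.plot(seq, M_mean)
    ax2.fill_between(seq, M_plo, M_phi, alpha=0.2)
    ax2.set_xlabel('Manipulated A'), ax2.set_ylabel('Counterfactual M')
    ax2.set_title('Counterfactual Effect of A on M')

plt.clf()
plot_counterfactuals()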
Counterfactual plots for the multivariate divorce model, m5_3_A. These plots visualize the predicted effect of manipulating age at marriage \(A\) on divorce rate \(D\). Left: Total causal effect of manipulating \(A\) (horizontal) on \(D\). This plot contains both paths, \(A → D\) and \(A → M → D\). Right: Simulated values of M show the estimated influence \(A →M\).
Left Plot: This predicted trend in \(D\) includes both paths: \(A → D\) and \(A → M → D\). We found previously that \(M → D\) is very small, so the second path doesn’t contribute much to the trend. But if \(M\) were to strongly influence \(D\), the code above would include the effect. The counterfactual simulation also generated values for \(M\). These are shown on the right. The object s from the code above includes these simulated \(M\) values.
Of course these calculations also permit numerical summaries. For example, the expected causal effect of increasing median age at marriage from 20 to 30 is:
# A_seq standardized to mean 26.1 and std dev 1.24
data = {
    'N': len(d),
    'A': d.A.tolist(),
    'M': d.M.tolist(),
    'D': d.D.tolist(),
    'N_tilde': 2,
    'A_seq': ((np.array([20, 30]) - 26.1) / 1.24).tolist(),
    'M_seq': [0, 0]
}
s2 = m5_3_A_model.sim(data=data, n=1000, select=['M_tilde', 'D_tilde'])
(s2['D_tilde'][:, 1] - s2['D_tilde'][:, 0]).mean()
np.float64(-4.54662228998)
This is a huge effect of four and one half standard deviations, probably impossibly large.
The trick with simulating counterfactuals is to realize that when we manipulate some variable \(X\), we break the causal influence of other variables on \(X\). This is the same as saying we modify the DAG so that no arrows enter \(X\). Suppose for example that we now simulate the effect of manipulating \(M\). This implies the DAG:
The arrow \(A → M\) is deleted, because if we control the values of \(M\), then \(A\) no longer influences it. It’s like a perfectly controlled experiment. Now we can modify the code above to simulate the counterfactual result of manipulating \(M\). We’ll simulate a counterfactual for an average state, with \(A =0\), and see what changing \(M\) does.
data = {
    'N': len(d),
    'A': d.A.tolist(),
    'M': d.M.tolist(),
    'D': d.D.tolist(),
    'N_tilde': len(seq),
    'A_seq': np.zeros(30).tolist(),
    'M_seq': seq.tolist()
}
s = m5_3_A_model.sim(data=data, n=1000, select=['D_tilde_M'])

def plot_counterfactual():
    D_mean, D_plo, D_phi = utils.precis(s['D_tilde_M'])
    plt.plot(seq, D_mean)
    plt.fill_between(seq, D_plo, D_phi, alpha=0.2)
    plt.ylim(-2.1, 2)
    plt.xlabel('Manipulated M'), plt.ylabel('Counterfactual D')
    plt.title('Total Counterfactual Effect of M on D')

plt.clf()
plot_counterfactual()
The counterfactual effect of manipulating marriage rate \(M\) on divorce rate \(D\). Since \(M →D\) was estimated to be very small, there is no strong trend here. By manipulating \(M\), we break the influence of \(A\) on \(M\), and this removes the association between \(M\) and \(D\).
We only simulate \(D\) now (we only retrieve D_tilde_M from our model). We don’t simulate \(A\), because \(M\) doesn’t influence it. The result is shown in the plot above. This trend is less strong, because there is no evidence for a strong influence of \(M\) on \(D\).
In more complex models with many potential paths, the same strategy will compute counterfactuals for an exposure of interest. But as we’ll see in later examples, often it is simply not possible to estimate a plausible, un-confounded causal effect of some exposure \(X\) on some outcome \(Y\). But even in those cases, there are still important counterfactuals to consider. So we’ll return to this theme in the future.
Simulating counterfactuals
The example in this section programmed code into the generated quantities block and used the .sim() method to access those simulated scenarios. But simulating counterfactuals on your own is not hard. It just uses the model definition. Assume we’ve already fit model m5_3_A, the model that includes both causal paths \(A → D\) and \(A →M →D\). We define a range of values that we want to assign to \(A\):
A_seq = np.linspace(-2, 2, 30)
Next we need to extract the posterior samples, because we’ll simulate observations for each set of samples. Then it really is just a matter of using the model definition with the samples, as in previous examples. The model defines the distribution of \(M\). We just convert that definition to the corresponding simulation function, which is stats.norm.rvs in this case:
post = m5_3_A_model.extract_samples(n=1000)
M_sim = stats.norm.rvs(
    loc=post['aM'] + post['bAM'] * A_seq.reshape(-1, 1),
    scale=post['sigma_M']
)
M_sim.T
The linear model inside stats.norm.rvs comes right out of the model definition (M ~ normal(mu_M, sigma_M)). This produces a matrix of values, with samples in rows and cases corresponding to the values in A_seq in the columns. Now that we have simulated values for \(M\), we can simulate \(D\) too:
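For example, a sketch that reuses the posterior samples and the simulated \(M\) values, with the linear model taken from the definition of mu in the Stan code (a + bA * A + bM * M):

# simulate D from its distribution, using the counterfactual A values and simulated M
D_sim = stats.norm.rvs(
    loc=post['a'] + post['bA'] * A_seq.reshape(-1, 1) + post['bM'] * M_sim,
    scale=post['sigma']
)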
If you plot A_seq against the row means of D_sim, you’ll see the same result as before. In complex models, there might be many more variables to simulate. But the basic procedure is the same.
def plot_counterfactual():
    D_mean, D_plo, D_phi = utils.precis(D_sim.T)
    plt.plot(A_seq, D_mean)
    plt.fill_between(A_seq, D_plo, D_phi, alpha=0.2)
    plt.xlabel('Manipulated A'), plt.ylabel('Counterfactual D')
    plt.title('Total Counterfactual Effect of A on D')

plt.clf()
plot_counterfactual()