Same model, better shape: why centering improves MCMC

The Emergency departments leading the transformation of Alzheimer’s and dementia care (ED-LEAD) study, which I have written about in the past, is approaching the end of its third year. This multifactorial design evaluates three independent, yet potentially synergistic, interventions aimed at improving care for persons living with dementia (PLWD) and their caregivers.

To estimate intervention effects, we are using what I’ve called the HEx-factor model, a Bayesian hierarchical exchangeable factorial model. The original plan was to conduct all analyses using Stan. However, we’ve run into a bit of a snafu. I’ve been working through the problem, and thought I’d share here.

The challenge turns out to be a computational one. Because the ED-LEAD analyses must be conducted on National Institute on Aging (NIA) Data LINKAGE servers, we are working in a somewhat restricted software environment, at least with respect to Bayesian data analysis. In particular, we have not been able to install or run Stan, which was our analytic engine of choice. This forced us to consider alternatives, and we turned to JAGS, which is available in the Linkage environment and certainly is well-suited for Bayesian hierarchical modeling.

At first glance, this might seem like a straightforward substitution. Both Stan and JAGS allow us to specify the same likelihood and priors. However, I quickly noticed that the models were not performing as well in JAGS as they had in Stan. It turns out that the samplers used in JAGS are more sensitive to posterior dependence than the Hamiltonian Monte Carlo (HMC) methods implemented in Stan.

I set out to understand and fix the problem, and found that a simple reparameterization—re-coding the binary treatment indicators—made a substantial difference. With this change, the JAGS sampler was able to explore the posterior distribution much more efficiently, yielding results comparable to those obtained with Stan.

To understand why this happens, I ran a series of simple simulations comparing the original and reparameterized versions of a basic two-way factorial model. That is what I present here.

The setup

In models with binary predictors and interactions, it turns out that centering can have a surprisingly large impact on computation, even though it does not change the underlying model. To see this clearly, I’ll start with a simple two-factor logistic model: \[ \text{logit}\big[P(Y=1)\big] = \alpha+ \beta_a A + \beta_b B + \beta_{ab}AB \] where \(A\) and \(B\) are binary treatment indicators. I’ll compare this to the algebraically equivalent centered version: \[ \text{logit}\big[P(Y=1)\big] = \alpha^*+ \gamma_a A^* + \gamma_b B^* + \gamma_{ab}A^*B^* \] where

\[ A^* = A - 0.5, \ \ \ B^* = B - 0.5. \]

The scientific model is unchanged. The question is whether the sampler behaves differently.

Log odds ratios under each parameterization

With 0/1 coding, the log-odds ratio for \(A\) alone (that is, when \(B=0\)) is simply

\[ \begin{align*} \text{lOR}_a(B=0) &= (\alpha + \beta_a \cdot 1 + \beta_b \cdot 0 + \beta_{ab} \cdot 0) - (\alpha + \beta_a \cdot 0 + \beta_b \cdot 0 + \beta_{ab} \cdot 0) \\ &= \beta_a \end{align*} \]

Analogously, the log-odds ratio for \(B\) alone is \(\beta_b\). And if we want to compare the combination of both \(A=1\) and \(B=1\) to the case where neither is activated, then

\[ \begin{align*} \text{lOR}_{ab} &= (\alpha + \beta_a \cdot 1 + \beta_b \cdot 1 + \beta_{ab} \cdot 1) \\ &\quad - (\alpha + \beta_a \cdot 0 + \beta_b \cdot 0 + \beta_{ab} \cdot 0) \\ &= \beta_a + \beta_b + \beta_{ab} \end{align*} \]

If instead we center the predictors, defining \(A^* = A - 0.5\) and \(B^* = B - 0.5\), then the log-odds ratio of exposure to \(A\) without exposure to \(B\), relative to exposure to neither, is

\[ \begin{align*} \text{lOR}_a(B=0) &= (\alpha^* + \gamma_a(0.5) + \gamma_b(-0.5) + \gamma_{ab}(0.5)(-0.5)) \\ &\quad - (\alpha^* + \gamma_a(-0.5) + \gamma_b(-0.5) + \gamma_{ab}(-0.5)(-0.5)) \\ &= (\alpha^* + 0.5\gamma_a - 0.5\gamma_b - 0.25\gamma_{ab}) \\ &\quad - (\alpha^* - 0.5\gamma_a - 0.5\gamma_b + 0.25\gamma_{ab}) \\ &= \gamma_a - 0.5\gamma_{ab}. \end{align*} \]

Using the same logic, we can show that

\[ \text{lOR}_b(A=0) = \gamma_b - 0.5\gamma_{ab} \]

and

\[ \text{lOR}_{ab} = \gamma_a + \gamma_b. \]
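As a numerical check of the algebra above, the sketch below evaluates the linear predictor in all four treatment cells under both codings. It assumes the mapping \(\alpha^* = \alpha + 0.5\beta_a + 0.5\beta_b + 0.25\beta_{ab}\), \(\gamma_a = \beta_a + 0.5\beta_{ab}\), \(\gamma_b = \beta_b + 0.5\beta_{ab}\), \(\gamma_{ab} = \beta_{ab}\), which follows from substituting \(A = A^* + 0.5\) and \(B = B^* + 0.5\) into the 0/1-coded model. The coefficient values are the data-generating values used later in this post; the object names are just for illustration.

```r
# Data-generating (0/1-coded) coefficients used later in the post
beta <- c(alpha = -0.8, beta_a = 0.5, beta_b = 0.9, beta_ab = -0.3)

# Implied centered coefficients (substitute A = A* + 0.5, B = B* + 0.5)
gamma <- c(
  alpha_star = unname(beta["alpha"] + 0.5 * beta["beta_a"] +
                        0.5 * beta["beta_b"] + 0.25 * beta["beta_ab"]),
  gamma_a    = unname(beta["beta_a"] + 0.5 * beta["beta_ab"]),
  gamma_b    = unname(beta["beta_b"] + 0.5 * beta["beta_ab"]),
  gamma_ab   = unname(beta["beta_ab"])
)

# The four treatment cells: (0,0), (1,0), (0,1), (1,1)
cells <- expand.grid(A = 0:1, B = 0:1)
X_01 <- cbind(1, cells$A, cells$B, cells$A * cells$B)
A_c <- cells$A - 0.5
B_c <- cells$B - 0.5
X_c <- cbind(1, A_c, B_c, A_c * B_c)

# Log odds in each cell under each coding: they should be identical
eta_01 <- drop(X_01 %*% beta)
eta_c  <- drop(X_c %*% gamma)
stopifnot(isTRUE(all.equal(eta_01, eta_c)))

# Log-odds ratios relative to the (A = 0, B = 0) cell
lor_from_beta <- c(
  lOR_a  = unname(beta["beta_a"]),
  lOR_b  = unname(beta["beta_b"]),
  lOR_ab = unname(sum(beta[c("beta_a", "beta_b", "beta_ab")]))
)
lor_from_gamma <- c(
  lOR_a  = unname(gamma["gamma_a"] - 0.5 * gamma["gamma_ab"]),
  lOR_b  = unname(gamma["gamma_b"] - 0.5 * gamma["gamma_ab"]),
  lOR_ab = unname(gamma["gamma_a"] + gamma["gamma_b"])
)
print(rbind(lor_from_beta, lor_from_gamma))
```

Both rows give log-odds ratios of 0.5, 0.9, and 1.1, even though the centered coefficients themselves (for example, \(\gamma_a = 0.35\)) differ from their 0/1-coded counterparts.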

Bayesian models using JAGS

The Bayesian model is a simple logistic regression with an interaction term:

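\[ \begin{align*} Y_i &\sim \text{Bernoulli}(p_i), \\ \text{logit}(p_i) &= \alpha + \beta_a A_i + \beta_b B_i + \beta_{ab} A_i B_i. \end{align*} \]

JAGS parameterizes the normal distribution by its precision (the inverse of the variance), so the priors below are written with each variance expressed as the inverse of the corresponding JAGS precision. A precision of 0.25 corresponds to a standard deviation of 2 on the log-odds scale, giving weak priors for the intercept and main effects; a precision of 25 corresponds to a standard deviation of 0.2, a much more restrictive prior for the interaction:

\[ \begin{align*} \alpha &\sim \mathcal{N}(0, 0.25^{-1}), \\ \beta_a &\sim \mathcal{N}(0, 0.25^{-1}), \\ \beta_b &\sim \mathcal{N}(0, 0.25^{-1}), \\ \beta_{ab} &\sim \mathcal{N}(0, 25^{-1}). \end{align*} \]

The centered model is similar, except that the coefficients are \(\alpha^*\), \(\gamma_a\), \(\gamma_b\), and \(\gamma_{ab}\) (with the same priors), and the predictors are the centered versions of \(A\) and \(B\). In the JAGS code below, the centered intercept \(\alpha^*\) is still named alpha.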
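To make the precision parameterization concrete, the short sketch below converts each JAGS precision into a variance, a standard deviation, and a central 95% prior interval on the log-odds scale. The vector names simply mirror the parameter names in the model; nothing here depends on JAGS itself.

```r
# JAGS's dnorm(mu, tau) uses the precision tau = 1 / variance
prior_tau <- c(alpha = 0.25, beta_a = 0.25, beta_b = 0.25, beta_ab = 25)

prior_var <- 1 / prior_tau
prior_sd  <- 1 / sqrt(prior_tau)   # standard deviation on the log-odds scale

# Central 95% prior interval for each coefficient
prior_95 <- cbind(
  lower = qnorm(0.025, mean = 0, sd = prior_sd),
  upper = qnorm(0.975, mean = 0, sd = prior_sd)
)

print(data.frame(tau = prior_tau, variance = prior_var, sd = prior_sd,
                 round(prior_95, 3)))
```

The interaction prior places 95% of its mass within about ±0.39 on the log-odds scale, compared with about ±3.9 for the main effects, which is why the interaction estimates are shrunk so noticeably later on.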

Simulations

Before we get started on the simulations, we need to load the necessary libraries and set the seed in case you want to replicate these results:

library(simstudy)
library(data.table)
library(ggplot2)
library(rjags)
library(coda)
library(posterior)
library(broom)
library(gt)
library(parallel)  # provides mclapply(), used in the simulation study

RNGkind("Mersenne-Twister", "Inversion", "Rejection")
set.seed(824)

Creating a single data set

Here is the data generation process for a single data set. The outcome \(Y\) is generated using the binary parameterization of \(A\) and \(B\):

s_gen <- function(n = 2000,
                  alpha = -0.8,
                  beta_a = 0.5,
                  beta_b = 0.9,
                  beta_ab = -0.3) {
  
  def <- 
    defData(varname = "A", formula = 0.5, dist = "binary") |>
    defData(varname = "B", formula = 0.5, dist = "binary") |>
    defData(varname = "AB", formula = "A*B", dist = "nonrandom") |>
    defData(varname = "A_c", formula = "A - 0.5", dist = "nonrandom") |>
    defData(varname = "B_c", formula = "B - 0.5", dist = "nonrandom") |>
    defData(varname = "AB_c", formula = "A_c * B_c", dist = "nonrandom") |>
    defData(
      varname = "Y", 
      formula = "..alpha + ..beta_a * A + ..beta_b * B + ..beta_ab * AB",
      dist = "binary", link = "logit"
    )
    
  genData(n, def)
  
}

dd <- s_gen()

The two parameterizations fit the same model

First, here is the frequentist check of both models. The fitted probabilities are identical, even though the coefficients differ.

fit_01 <- glm(Y ~ A * B, data = dd, family = binomial)
fit_c  <- glm(Y ~ A_c * B_c, data = dd, family = binomial)

tidy(fit_01)
## # A tibble: 4 × 5
##   term        estimate std.error statistic  p.value
##   <chr>          <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)   -1.03     0.0995    -10.4  3.90e-25
## 2 A              0.835    0.135       6.18 6.34e-10
## 3 B              1.14     0.134       8.46 2.73e-17
## 4 A:B           -0.670    0.186      -3.61 3.11e- 4
tidy(fit_c)
## # A tibble: 4 × 5
##   term        estimate std.error statistic  p.value
##   <chr>          <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)   -0.212    0.0464     -4.57 4.91e- 6
## 2 A_c            0.501    0.0929      5.39 7.05e- 8
## 3 B_c            0.802    0.0929      8.63 6.07e-18
## 4 A_c:B_c       -0.670    0.186      -3.61 3.11e- 4

From the 0/1-coded model, \(\text{lOR}_a = 0.835\), \(\text{lOR}_b = 1.14\), and \(\text{lOR}_{ab} = 0.835 + 1.14 - 0.67 = 1.305.\)

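From the centered model, using the expressions derived earlier, \[ \text{lOR}_a = \gamma_a - 0.5\gamma_{ab} = 0.501 - 0.5 \times (-0.670) = 0.836 \] \[ \text{lOR}_b = \gamma_b - 0.5\gamma_{ab} = 0.802 - 0.5 \times (-0.670) = 1.137 \] \[ \text{lOR}_{ab} = \gamma_a + \gamma_b = 0.501 + 0.802 = 1.303 \] So the coefficients themselves change under centering, but the underlying treatment contrasts do not (the small differences in the third decimal place are due to rounding).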
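The claim that the fitted probabilities are identical can also be checked directly. The sketch below is self-contained: it simulates its own data with base R (rbinom() and plogis()) rather than simstudy, using the same data-generating coefficients as s_gen() and an arbitrary seed, so its estimates will not match the output above. It confirms that the two fits produce the same fitted probabilities and the same log-odds ratios.

```r
set.seed(2024)  # arbitrary seed for this sketch
n <- 2000
A <- rbinom(n, 1, 0.5)
B <- rbinom(n, 1, 0.5)
Y <- rbinom(n, 1, plogis(-0.8 + 0.5 * A + 0.9 * B - 0.3 * A * B))
sim <- data.frame(Y, A, B, A_c = A - 0.5, B_c = B - 0.5)

fit_01 <- glm(Y ~ A * B, data = sim, family = binomial)
fit_c  <- glm(Y ~ A_c * B_c, data = sim, family = binomial)

# Same fitted probabilities (up to IRLS convergence tolerance)
stopifnot(isTRUE(all.equal(fitted(fit_01), fitted(fit_c), tolerance = 1e-6)))

# Log-odds ratios recovered from each set of coefficients
b <- coef(fit_01)
g <- coef(fit_c)
lor <- rbind(
  `0/1-coded` = c(lOR_a  = b[["A"]],
                  lOR_b  = b[["B"]],
                  lOR_ab = b[["A"]] + b[["B"]] + b[["A:B"]]),
  centered    = c(lOR_a  = g[["A_c"]] - 0.5 * g[["A_c:B_c"]],
                  lOR_b  = g[["B_c"]] - 0.5 * g[["A_c:B_c"]],
                  lOR_ab = g[["A_c"]] + g[["B_c"]])
)
print(round(lor, 3))
```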

Specifying the Bayesian models in JAGS

Next, we'll see that the two Bayesian models also recover the same treatment contrasts, but that the centered version performs considerably better computationally.

Here is the JAGS code for each model:

model_01 <- "
model {
  for (i in 1:N) {
    Y[i] ~ dbern(p[i])
    logit(p[i]) <- alpha + beta_a * A[i] + beta_b * B[i] + beta_ab * AB[i]
  }
  
  alpha   ~ dnorm(0, 0.25)
  beta_a  ~ dnorm(0, 0.25)
  beta_b  ~ dnorm(0, 0.25)
  beta_ab ~ dnorm(0, 25)
}
"

model_c <- "
model {
  for (i in 1:N) {
    Y[i] ~ dbern(p[i])
    logit(p[i]) <- alpha + gamma_a * A_c[i] + gamma_b * B_c[i] + gamma_ab * AB_c[i]
  }
  
  alpha   ~ dnorm(0, 0.25)
  gamma_a  ~ dnorm(0, 0.25)
  gamma_b  ~ dnorm(0, 0.25)
  gamma_ab ~ dnorm(0, 25)
}
"

Fitting the models

The function fit_jags fits one of the two models just described:

fit_jags <- function(dat, model_string, centered = FALSE,
                     n_chains = 3, burn = 2000, n_iter = 5000) {
  
  # Pass JAGS only the variables the chosen model uses; extra data
  # variables would trigger "unused variable" warnings in jags.model()
  if (centered) {
    jdat <- as.list(dat[, .(Y, A_c, B_c, AB_c)])
    vars <- c("alpha", "gamma_a", "gamma_b", "gamma_ab")
  } else {
    jdat <- as.list(dat[, .(Y, A, B, AB)])
    vars <- c("alpha", "beta_a", "beta_b", "beta_ab")
  }
  jdat$N <- nrow(dat)
  
  mod <- jags.model(
    textConnection(model_string),
    data = jdat,
    n.chains = n_chains,
    quiet = TRUE
  )
  
  update(mod, burn, progress.bar = "none")
  
  samp <- coda.samples(
    mod,
    variable.names = vars,
    n.iter = n_iter,
    progress.bar = "none"
  )
  
  samp
}

Now, we can fit the models, collect the diagnostic data, and take a look at the results:

samp_01 <- fit_jags(dd, model_01, centered = FALSE)
samp_c  <- fit_jags(dd, model_c, centered = TRUE)

diag_tbl <- function(samp, model_name) {
  post <- as_draws_df(samp)
  summ <- summarise_draws(post)
  out <- as.data.table(summ)
  out[, model := model_name]
  out[]
}

diag_01 <- diag_tbl(samp_01, "0/1-coded")
diag_c  <- diag_tbl(samp_c, "centered")

Here are the summary statistics of the posterior distribution as well as the computational diagnostics:

Model      Parameter    Mean  Median     SD    MAD  5th %tile  95th %tile  R-hat  ESS (bulk)  ESS (tail)
0/1-coded  alpha      -0.937  -0.937  0.092  0.093     -1.088      -0.786  1.001    1359.836    2862.560
0/1-coded  beta_a      0.665   0.665  0.117  0.118      0.472       0.855  1.001    1612.670    3084.048
0/1-coded  beta_ab    -0.352  -0.352  0.137  0.139     -0.579      -0.128  1.001    1699.870    3617.854
0/1-coded  beta_b      0.969   0.968  0.118  0.119      0.777       1.164  1.001    1430.007    3298.993
centered   alpha      -0.210  -0.210  0.047  0.047     -0.285      -0.133  1.000    9278.897    9114.590
centered   gamma_a     0.492   0.492  0.094  0.094      0.337       0.647  1.000    9316.962    8853.810
centered   gamma_ab   -0.361  -0.362  0.135  0.134     -0.582      -0.140  1.000    9167.199    9311.161
centered   gamma_b     0.795   0.795  0.093  0.095      0.642       0.946  1.001    8969.964    8821.275

There are a few things to notice here. First, the Bayesian estimates from both the 0/1-coded and centered models are closer to zero than the corresponding GLM estimates above. The shrinkage is particularly large for the interaction term, because we placed a much more restrictive prior on \(\beta_{ab}\) and \(\gamma_{ab}\). This is what we would expect, as the prior pulls the interaction toward zero.

Second, comparing the two parameterizations, R-hat (essentially a measure of whether the chains have converged to the same distribution) is slightly lower for the centered model. There isn't much to make of the difference here (both are very close to 1), but it does suggest slightly more stable behavior for the centered parameterization.

The biggest impact is on the bulk effective sample size (ESS), which reflects how much independent information the chains contain after accounting for autocorrelation. Even though both models were run for the same number of iterations (3 chains of 5,000 draws each, or 15,000 draws in total), bulk ESS is only about 1,360 to 1,700 for the 0/1-coded model, compared with roughly 9,000 to 9,300 for the centered model, a five- to seven-fold improvement that indicates much better mixing. Tail ESS improves as well, by roughly a factor of three. The sampler is exploring the posterior far more efficiently under the centered parameterization. Importantly, these differences have nothing to do with the model itself, since the likelihood is unchanged. Rather, they reflect how easy it is for the sampler to navigate the posterior surface when the predictors are centered.
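The link between autocorrelation and ESS can be illustrated without JAGS. The sketch below simulates two AR(1) series with base R's arima.sim() as stand-ins for MCMC chains, one weakly and one strongly autocorrelated, and computes a simple single-chain ESS, \(n / (1 + 2\sum_k \rho_k)\), truncating the sum at the first non-positive sample autocorrelation. The simple_ess() helper is hypothetical and much cruder than the rank-normalized, multi-chain estimators used by summarise_draws() in the posterior package, so treat it as an illustration only.

```r
# Simplified single-chain ESS: n / (1 + 2 * sum of autocorrelations),
# truncating the sum at the first non-positive lag
simple_ess <- function(x, max_lag = 200) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_nonpos <- which(rho <= 0)[1]
  if (!is.na(first_nonpos)) rho <- rho[seq_len(first_nonpos - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(1)
n <- 5000

# AR(1) series standing in for weakly and strongly autocorrelated chains
chain_low  <- as.numeric(arima.sim(list(ar = 0.1), n = n))
chain_high <- as.numeric(arima.sim(list(ar = 0.9), n = n))

# For an AR(1) process with coefficient phi, ESS is about n * (1 - phi) / (1 + phi)
theory   <- c(low = n * 0.9 / 1.1, high = n * 0.1 / 1.9)
estimate <- c(low = simple_ess(chain_low), high = simple_ess(chain_high))
print(round(rbind(theory, estimate)))
```

With the same 5,000 draws, the strongly autocorrelated series carries far less independent information than the weakly autocorrelated one, which is the same phenomenon separating the 0/1-coded and centered models.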

A comparison of the trace plots reinforces the stability that centering provides. The traces for the 0/1-coded model (on the left) are somewhat more irregular, suggesting less efficient exploration of the posterior. In contrast, the centered parameterization (on the right) produces tighter, more stable traces with less autocorrelation, indicating that the chains are mixing more effectively. This aligns with the much larger effective sample sizes observed for the centered model.

Finally, we compare the estimated log-odds ratios from the two models, just as we did earlier with the GLM fits. The two Bayesian models provide essentially the same estimates of the contrasts:

get_lor_summary <- function(samp, model_name) {
  dt <- as.data.table(as_draws_df(samp))
  
  if (model_name == "0/1-coded") {
    dt[, lOR_A := beta_a]
    dt[, lOR_B := beta_b]
    dt[, lOR_AB := beta_a + beta_b + beta_ab]
  } else {
    dt[, lOR_A := gamma_a - 0.5 * gamma_ab]
    dt[, lOR_B := gamma_b - 0.5 * gamma_ab]
    dt[, lOR_AB := gamma_a + gamma_b]
  }
  
  dt[, .(
    mean_A = mean(lOR_A),
    mean_B = mean(lOR_B),
    mean_AB = mean(lOR_AB),
    sd_A = sd(lOR_A),
    sd_B = sd(lOR_B),
    sd_AB = sd(lOR_AB)
  )]
}

lor_01 <- get_lor_summary(samp_01, "0/1-coded")
lor_c  <- get_lor_summary(samp_c,  "centered")

Posterior mean (SD) of each log-odds ratio:

Model      log OR A       log OR B       log OR AB
0/1-coded  0.672 (0.115)  0.973 (0.116)  1.286 (0.132)
centered   0.672 (0.116)  0.975 (0.117)  1.288 (0.134)

A larger simulation experiment

A single data set can be misleading, so next I'll repeat the process 500 times and compare the two parameterizations across simulations. In each iteration, I generate a data set with 2,000 observations, fit both the 0/1-coded and the centered models in JAGS (with a somewhat shorter run of 1,000 burn-in and 3,000 sampling iterations per chain), and collect summaries of each posterior: mean, median, standard deviation, median absolute deviation, 5th percentile, 95th percentile, R-hat, bulk ESS, and tail ESS.

one_run <- function(
  n = 2000,
  truth = c(alpha = -0.8, beta_a = 0.5, beta_b = 0.9, beta_ab = -0.3),
  n_chains = 3,
  burn = 1000,
  n_iter = 3000
) {
  
  dd <- s_gen(
    n = n,
    alpha = truth["alpha"],
    beta_a = truth["beta_a"],
    beta_b = truth["beta_b"],
    beta_ab = truth["beta_ab"]
  )
  
  samp_01 <- fit_jags(
    dd, model_01, centered = FALSE, 
    n_chains = n_chains, burn = burn, n_iter = n_iter
  )
  
  samp_c <- fit_jags(
    dd, model_c, centered = TRUE, 
    n_chains = n_chains, burn = burn, n_iter = n_iter)
  
  get_metrics <- function(samp, model_name) {
    post <- as_draws_df(samp)
    summ <- as.data.table(summarise_draws(post))
    summ[, model := model_name]
    summ[]
  }
  
  out <- rbindlist(list(
    get_metrics(samp_01, "0/1-coded"),
    get_metrics(samp_c,  "centered")
  ))
  
  out[]
}

nsim <- 500

sim_res <- rbindlist(mclapply(seq_len(nsim), function(i) {
  out <- one_run()
  out[, sim := i]
  out[]
}, mc.cores = 5))

Earlier, for a single data set, we saw little difference in R-hat between the two models. Over repeated data sets, however, a more interesting picture emerges. The figure below shows that while R-hat for the 0/1-coded model is already close to 1, R-hat for the centered model is closer still, and much more consistent across data sets, suggesting better mixing in the centered model.

The next figure also confirms what we saw earlier. It shows the distribution, across simulated data sets, of the ratio of bulk ESS in the centered model to bulk ESS in the 0/1-coded model. If the two models had the same effective sample size, we would expect these ratios to cluster near one. Instead, most are greater than five, consistent with what we saw for the single data set.

The key issue is posterior dependence among parameters: when parameters are highly correlated, the posterior mass is concentrated along narrow ridges, and the sampler can only move along them in small steps, which slows mixing.

Understanding what is driving the performance

To better understand this, we can look directly at the dependence structure of the posterior draws. Correlation plots (where each point is a draw from the posterior) help explain what is driving these differences in performance. Under the 0/1-coded parameterization, the posterior exhibits strong dependence among parameters. Several pairs of coefficients show substantial correlations, reflecting the fact that different combinations of parameters can produce similar fitted values. In geometric terms, the joint posterior has an elongated, highly correlated structure. This is evident in the pairwise scatter plots, where draws fall along narrow, tilted bands rather than forming roughly circular clouds.

This geometry makes life difficult for the sampler. Exploring a narrower region requires smaller, correlated steps, which leads to high autocorrelation and, ultimately, low effective sample sizes.
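To see why correlation translates into autocorrelation, consider an idealized case: a Gibbs sampler that alternates between the two coordinates of a standard bivariate normal with correlation \(\rho\). Each full sweep moves \(x\) according to \(x_t = \rho^2 x_{t-1} + \varepsilon_t\), so the lag-1 autocorrelation of the chain is \(\rho^2\). The gibbs_bvn() helper below is a hypothetical illustration, not what JAGS does for a logistic regression (JAGS chooses its own samplers), but the same mechanism is at work whenever strongly correlated parameters are updated one at a time.

```r
# Two-block Gibbs sampler for a standard bivariate normal with correlation rho
gibbs_bvn <- function(rho, n_iter = 20000, x0 = 0) {
  x <- numeric(n_iter)
  y <- numeric(n_iter)
  cond_sd <- sqrt(1 - rho^2)  # conditional SD of each coordinate
  x_cur <- x0
  for (t in seq_len(n_iter)) {
    y_cur <- rnorm(1, mean = rho * x_cur, sd = cond_sd)  # y | x
    x_cur <- rnorm(1, mean = rho * y_cur, sd = cond_sd)  # x | y
    x[t] <- x_cur
    y[t] <- y_cur
  }
  cbind(x = x, y = y)
}

set.seed(42)

# Lag-1 autocorrelation of the x-chain for weak and strong posterior correlation
lag1 <- sapply(c(weak = 0.1, strong = 0.95), function(rho) {
  draws <- gibbs_bvn(rho)
  acf(draws[, "x"], lag.max = 1, plot = FALSE)$acf[2]
})
print(round(lag1, 3))  # theory: rho^2, about 0.01 and 0.90
```

With \(\rho = 0.95\), successive draws are nearly redundant, so many more iterations are needed to achieve the same effective sample size; with \(\rho = 0.1\), the chain behaves almost like independent sampling.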

In contrast, the centered parameterization produces a posterior that is nearly uncorrelated. The coefficients capture more distinct aspects of the model, and the resulting posterior is much more spherical. This greatly simplifies the exploration of the parameter space, allowing the sampler to move more freely.
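The contrast in posterior dependence can be previewed without running JAGS at all. The sketch below uses the large-sample covariance of the maximum likelihood estimates from glm() as a rough stand-in for the posterior covariance; this is an approximation that is reasonable here because the sample is large and the main-effect priors are weak, although the tight prior on the interaction will modify the actual posterior somewhat. The data are simulated with base R from the same coefficients as s_gen(), using an arbitrary seed.

```r
set.seed(101)  # arbitrary seed for this sketch
n <- 2000
A <- rbinom(n, 1, 0.5)
B <- rbinom(n, 1, 0.5)
Y <- rbinom(n, 1, plogis(-0.8 + 0.5 * A + 0.9 * B - 0.3 * A * B))
A_c <- A - 0.5
B_c <- B - 0.5

fit_01 <- glm(Y ~ A * B, family = binomial)
fit_c  <- glm(Y ~ A_c * B_c, family = binomial)

# Correlation matrices implied by the estimated covariance of the coefficients
corr_01 <- cov2cor(vcov(fit_01))
corr_c  <- cov2cor(vcov(fit_c))
print(round(corr_01, 2))
print(round(corr_c, 2))

# Largest absolute off-diagonal correlation under each coding
max_offdiag <- function(m) max(abs(m[upper.tri(m)]))
max_corr <- c(`0/1-coded` = max_offdiag(corr_01), centered = max_offdiag(corr_c))
print(round(max_corr, 2))
```

Under 0/1 coding, pairs such as the intercept and \(\beta_a\), or \(\beta_a\) and \(\beta_{ab}\), should show correlations of roughly 0.7 in magnitude, while under centering all off-diagonal correlations should be close to zero, mirroring the tilted bands versus round clouds in the posterior scatter plots.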

The key point is that centering does not change the model or the scientific conclusions. It changes the geometry of the posterior distribution, and that change can have a dramatic impact on computational performance. In effect, centering makes the parameters closer to orthogonal in the posterior, so that moving one parameter no longer requires compensating moves in the others. The result is better mixing and more precise Monte Carlo estimates for the same number of iterations.

In the ED-LEAD study, where we are fitting hierarchical factorial models with multiple intervention components, this shift in parameterization is critical. Centering the treatment indicators leads to more stable estimation and far more efficient sampling, which is particularly important given our reliance on JAGS. Unlike Hamiltonian Monte Carlo (as implemented in Stan), which can handle correlated posteriors more effectively, the Gibbs and Metropolis-based updates used by JAGS are much more sensitive to posterior dependence. Improving the geometry of the posterior seems to be critical for good performance in this setting.
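The same idea extends to designs with more factors. As a hypothetical illustration (three factors A, B, and C in a balanced full factorial with one row per treatment combination, covering only the fixed-effects part of the model rather than the hierarchical HEx-factor structure), the sketch below compares the cross-product matrices of the 0/1-coded and ±0.5-coded design matrices with all interactions. With centering, every pair of columns is orthogonal; with 0/1 coding, the intercept, main effects, and interactions overlap. In a logistic model, the observation weights \(p(1-p)\) differ across cells, so orthogonality in the posterior is only approximate.

```r
# Balanced 2 x 2 x 2 factorial: one row per treatment combination
cells <- expand.grid(A = 0:1, B = 0:1, C = 0:1)

# Design matrices with all interactions: 0/1 coding vs. centered (+/- 0.5) coding
X_01 <- model.matrix(~ A * B * C, data = cells)
X_c  <- model.matrix(~ A * B * C, data = cells - 0.5)

# Off-diagonal entries of X'X measure overlap between columns
xtx_01 <- crossprod(X_01)
xtx_c  <- crossprod(X_c)

off_diag <- function(m) m[row(m) != col(m)]
print(max(abs(off_diag(xtx_01))))  # 4
print(max(abs(off_diag(xtx_c))))   # 0
```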

Support:

This work was supported in part by the National Institute on Aging (NIA) of the National Institutes of Health under Award Number U19AG078105, which funds the Emergency departments leading the transformation of Alzheimer’s and dementia care (ED-LEAD) study. The author, the leader of the Statistics Analysis Core, was the sole writer of this blog post and has no conflicts. The content is solely the responsibility of the author and does not necessarily represent the official views of the National Institutes of Health.
