Would you like all your posteriors in one plot?

A colleague reached out to me earlier this week with a plotting question. They had fit a series of Bayesian models, all containing a common parameter of interest. They knew how to plot their focal parameter one model at a time, but were stumped on how to combine the plots across models into a seamless whole. It reminded me a bit of this gif, which I originally got from Jenny Bryan's great talk, Behind every great plot there's a great deal of wrangling.

The goal of this post is to provide solutions. We’ll practice a few different ways you can combine the posterior samples from your Bayesian models into a single plot. As usual, we’ll be fitting our models with brms, wrangling with packages from the tidyverse, and getting a little help from the tidybayes package.

I make assumptions.

For this post, I'm presuming you are familiar with Bayesian regression using brms. I'm also assuming you've coded with some of the foundational functions from the tidyverse. If you'd like to firm up your foundations a bit, check out these resources.

  • To learn about Bayesian regression, I recommend the introductory textbooks by either McElreath (here) or Kruschke (here). Both authors host blogs (here and here, respectively). If you go with McElreath, do check out his online lectures and my ebooks where I translated his text to brms and tidyverse code (here and here). I have a similar ebook translation for Kruschke's text (here).
  • For even more brms-related resources, you can find vignettes and documentation here.
  • For tidyverse introductions, your best bets are R4DS and The tidyverse style guide.

Same parameter, different models

Let’s load our primary statistical packages.

library(tidyverse)
library(brms)
library(tidybayes)

Simulate \(n = 150\) draws from the standard normal distribution.

n <- 150

set.seed(1)
d <-
  tibble(y = rnorm(n, mean = 0, sd = 1))

head(d)
## # A tibble: 6 x 1
##        y
##    <dbl>
## 1 -0.626
## 2  0.184
## 3 -0.836
## 4  1.60 
## 5  0.330
## 6 -0.820

Here we’ll fit three intercept-only models for y. Each will follow the form

\[ \begin{align*} y_i & \sim \text{Normal} (\mu, \sigma) \\ \mu & = \beta_0 \\ \beta_0 & \sim \text{Normal} (0, x) \\ \sigma & \sim \text{Student-t}(3, 0, 10) \end{align*} \]

where \(\beta_0\) is the unconditional intercept (i.e., an intercept not conditioned on any predictors). We will be fitting three alternative models. All will have the same prior for \(\sigma\), \(\text{Student-t}(3, 0, 10)\), which is the brms default in this case. [If you’d like to check, use the get_prior() function.] The only way the models will differ is by their prior on the intercept \(\beta_0\). By model, those priors will be

  • fit1: \(\beta_0 \sim \text{Normal} (0, 10)\),
  • fit2: \(\beta_0 \sim \text{Normal} (0, 1)\), and
  • fit3: \(\beta_0 \sim \text{Normal} (0, 0.1)\).

So if you were wondering, the \(x\) in the \(\beta_0 \sim \text{Normal} (0, x)\) line, above, was a stand-in for the varying hyperparameter.

Here we fit the models in bulk.

fit1 <-
  brm(data = d,
      family = gaussian,
      y ~ 1,
      prior(normal(0, 10), class = Intercept),
      seed = 1)

fit2 <-
  update(fit1,
         prior = prior(normal(0, 1), class = Intercept),
         seed = 1)

fit3 <-
  update(fit1,
         prior = prior(normal(0, 0.1), class = Intercept),
         seed = 1)

Normally we’d use plot() to make sure the chains look good and then use something like print() or posterior_summary() to summarize the models’ results. I’ve checked and they’re all fine. For the sake of space, let’s press forward.

If you were going to plot the results of an individual fit using something like the tidybayes::stat_halfeye() function, the next step would be extracting the posterior draws. Here we’ll do so with the brms::posterior_samples() function.

post1 <- posterior_samples(fit1)
post2 <- posterior_samples(fit2)
post3 <- posterior_samples(fit3)

Focusing on fit1, here’s how we’d plot the results for the intercept \(\beta_0\).

# this part is unnecessary; it just adjusts some theme defaults to my liking
theme_set(theme_gray() +
            theme(axis.text.y  = element_text(hjust = 0),
                  axis.ticks.y = element_blank(),
                  panel.grid   = element_blank()))

# plot!
post1 %>% 
  ggplot(aes(x = b_Intercept, y = 0)) +
  stat_halfeye() +
  scale_y_continuous(NULL, breaks = NULL)

But how might we get the posterior draws from all three fits into one plot? We need to combine the posterior draws from each into a single data frame. There are many ways to do this. Perhaps the simplest is with the bind_rows() function.

posts <-
  bind_rows(
    post1,
    post2,
    post3
  ) %>% 
  mutate(prior = str_c("normal(0, ", c(10, 1, 0.1), ")") %>% rep(., each = 4000))

head(posts)
##   b_Intercept     sigma      lp__         prior
## 1  0.06440413 0.9408454 -202.2537 normal(0, 10)
## 2  0.02603356 0.9416735 -202.1114 normal(0, 10)
## 3 -0.02122717 0.8967501 -202.0446 normal(0, 10)
## 4  0.02620046 0.9521795 -202.2594 normal(0, 10)
## 5  0.02620046 0.9521795 -202.2594 normal(0, 10)
## 6  0.08025366 0.9101939 -202.1808 normal(0, 10)

The bind_rows() function worked well here because all three post objects had the same number of columns with the same names, so we just stacked them three high. That is, we went from three data objects of 4,000 rows and 3 columns to one data object with 12,000 rows and 3 columns. With the mutate() function, we then added a fourth column, prior, that indexes which model each row came from. Note that rep(., each = 4000) assumes each fit has the brms default of 4,000 post-warmup draws (4 chains of 1,000). Now that our data are ready, we can plot.

posts %>% 
  ggplot(aes(x = b_Intercept, y = prior)) +
  stat_halfeye()

Our plot arrangement made it easy to compare the results of tightening the prior on \(\beta_0\); the narrower the prior, the narrower the posterior.
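For a rough sense of why the pattern looks the way it does, here's a back-of-the-envelope sketch. It assumes \(\sigma\) is known and fixed at 1 (the true value in our simulation; the brms models actually estimate \(\sigma\)), in which case the normal prior on \(\beta_0\) is conjugate and the posterior precision is just the prior precision plus the data precision, \(1 / x^2 + n / \sigma^2\). It's an approximation, not a reproduction of the brms results.

```r
# Normal-normal approximation for the intercept-only models, treating sigma as known.
n        <- 150             # sample size from the simulation
sigma    <- 1               # assumption: sigma fixed at its true simulated value
prior_sd <- c(10, 1, 0.1)   # the x in beta_0 ~ Normal(0, x) for fit1, fit2, fit3

# Posterior precision = prior precision + data precision
post_prec <- 1 / prior_sd^2 + n / sigma^2
post_sd   <- 1 / sqrt(post_prec)
round(post_sd, 4)  # 0.0816 0.0814 0.0632

# Share of the posterior mean that comes from the sample mean (the rest is pulled toward 0)
shrink <- (n / sigma^2) / post_prec
round(shrink, 3)   # about 1, 0.993, and 0.6
```

Under this approximation, the normal(0, 10) and normal(0, 1) priors give nearly identical posterior widths, and the visible narrowing only shows up with normal(0, 0.1), whose prior precision (100) is on the same order as the data's (150).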

What if my posterior_samples() aren’t of the same dimensions across models?

For the next examples, we need new data. Here we'll simulate three predictors: x1, x2, and x3. We then simulate our criterion y as a linear additive function of those predictors.

set.seed(1)
d <-
  tibble(x1 = rnorm(n, mean = 0, sd = 1),
         x2 = rnorm(n, mean = 0, sd = 1),
         x3 = rnorm(n, mean = 0, sd = 1)) %>% 
  mutate(y  = rnorm(n, mean = 0 + x1 * 0 + x2 * 0.2 + x3 * -0.4))

head(d)
## # A tibble: 6 x 4
##       x1      x2     x3      y
##    <dbl>   <dbl>  <dbl>  <dbl>
## 1 -0.626  0.450   0.894  0.694
## 2  0.184 -0.0186 -1.05  -0.189
## 3 -0.836 -0.318   1.97  -1.61 
## 4  1.60  -0.929  -0.384 -1.59 
## 5  0.330 -1.49    1.65  -2.41 
## 6 -0.820 -1.08    1.51  -0.764

We are going to work with these data in two ways. For the first example, we’ll fit a series of univariable models following the same basic form, but each with a different predictor. For the second example, we’ll fit a series of multivariable models with various combinations of the predictors. Each requires its own approach.

Same form, different predictors.

This time we’re just using the brms default priors. As such, the models all follow the form

\[ \begin{align*} y_i & \sim \text{Normal} (\mu_i, \sigma) \\ \mu_i & = \beta_0 + \beta_n x_n\\ \beta_0 & \sim \text{Student-t}(3, 0, 10) \\ \sigma & \sim \text{Student-t}(3, 0, 10) \end{align*} \]

You may be wondering: What about the prior for \(\beta_n\)? The brms defaults for those are improper flat priors. We define \(\beta_n x_n\) for the next three models as

  • fit4: \(\beta_1 x_1\),
  • fit5: \(\beta_2 x_2\), and
  • fit6: \(\beta_3 x_3\).

Let’s fit the models.

fit4 <-
  brm(data = d,
      family = gaussian,
      y ~ 1 + x1,
      seed = 1)

fit5 <-
  update(fit4,
         newdata = d,
         y ~ 1 + x2,
         seed = 1)

fit6 <-
  update(fit4,
         newdata = d,
         y ~ 1 + x3,
         seed = 1)

Like before, save the posterior draws for each as separate data frames.

post4 <- posterior_samples(fit4)
post5 <- posterior_samples(fit5)
post6 <- posterior_samples(fit6)

This time, our simple bind_rows() trick won’t work well.

bind_rows(
  post4,
  post5,
  post6
) %>% 
  head()
##   b_Intercept        b_x1    sigma      lp__ b_x2 b_x3
## 1 -0.26609646 -0.07795464 1.249694 -242.9716   NA   NA
## 2 -0.11933443 -0.03143494 1.251379 -240.4618   NA   NA
## 3 -0.10952301  0.02739295 1.278072 -241.2102   NA   NA
## 4 -0.08785528 -0.01065453 1.443157 -245.2715   NA   NA
## 5 -0.22020421 -0.16635358 1.185220 -241.7569   NA   NA
## 6  0.02973246 -0.13106488 1.123438 -239.2940   NA   NA

We don't want separate columns for b_x1, b_x2, and b_x3. We want them all stacked atop one another. One simple solution is a two-step approach in which we (1) select the relevant column from each and bind them together with bind_cols() and then (2) stack them atop one another with the gather() function.

posts <-
  bind_cols(
    post4 %>% select(b_x1),
    post5 %>% select(b_x2),
    post6 %>% select(b_x3)
  ) %>% 
  gather() %>% 
  mutate(predictor = str_remove(key, "b_"))

head(posts)
##    key       value predictor
## 1 b_x1 -0.07795464        x1
## 2 b_x1 -0.03143494        x1
## 3 b_x1  0.02739295        x1
## 4 b_x1 -0.01065453        x1
## 5 b_x1 -0.16635358        x1
## 6 b_x1 -0.13106488        x1

That mutate() line at the end wasn’t necessary, but it will make the plot more attractive.

posts %>% 
  ggplot(aes(x = value, y = predictor)) +
  stat_halfeye()
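Since tidyr 1.0.0, gather() has been superseded by pivot_longer(). Here's a minimal sketch of the same reshaping step with pivot_longer(), assuming tidyr 1.0.0 or later. The draws data frame is a simulated stand-in for the bind_cols() result, not real brms output. The names_prefix argument strips the leading "b_", which takes care of the str_remove() step.

```r
library(tidyr)

set.seed(1)
# Simulated stand-in for the bind_cols() result: one column of draws per slope
draws <- data.frame(b_x1 = rnorm(4000, mean = 0,    sd = 0.1),
                    b_x2 = rnorm(4000, mean = 0.2,  sd = 0.1),
                    b_x3 = rnorm(4000, mean = -0.4, sd = 0.1))

# Stack all three columns; column names (minus "b_") go to predictor, draws go to value
posts <- pivot_longer(draws,
                      cols         = everything(),
                      names_to     = "predictor",
                      names_prefix = "b_",
                      values_to    = "value")
```

With the real objects, you'd pipe the bind_cols() output into the same pivot_longer() call and then into ggplot() and stat_halfeye() as above. One difference from gather(): pivot_longer() orders the rows by draw, cycling through x1, x2, and x3, rather than stacking all the x1 draws first. The plot doesn't depend on row order.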

Different combinations of predictors in different forms.

Now we fit a series of multivariable models. The first three will have combinations of two of the predictors. The final model will have all three. For simplicity, we continue to use the brms default priors.

fit7 <-
  brm(data = d,
      family = gaussian,
      y ~ 1 + x1 + x2,
      seed = 1)

fit8 <-
  update(fit7,
         newdata = d,
         y ~ 1 + x1 + x3,
         seed = 1)

fit9 <-
  update(fit7,
         newdata = d,
         y ~ 1 + x2 + x3,
         seed = 1)

fit10 <-
  update(fit7,
         newdata = d,
         y ~ 1 + x1 + x2 + x3,
         seed = 1)

Individually extract the posterior draws.

post7  <- posterior_samples(fit7)
post8  <- posterior_samples(fit8)
post9  <- posterior_samples(fit9)
post10 <- posterior_samples(fit10)

Take a look at what happens this time when we use the bind_rows() approach.

posts <-
  bind_rows(
    post7,
    post8,
    post9,
    post10
  ) 

glimpse(posts)
## Rows: 16,000
## Columns: 6
## $ b_Intercept <dbl> -0.034398871, 0.008116322, 0.109134954, -0.134114504, -0.148230448, 0.04629622…
## $ b_x1        <dbl> -0.018887709, -0.156024614, -0.248414749, 0.057442787, 0.241874229, -0.3504998…
## $ b_x2        <dbl> 0.23847261, 0.27500306, 0.37294396, 0.20640317, 0.15437136, 0.28201317, 0.1538…
## $ sigma       <dbl> 1.250134, 1.065501, 1.029253, 1.220301, 1.206074, 1.114755, 1.180636, 1.266597…
## $ lp__        <dbl> -236.9970, -236.7477, -241.3055, -237.9540, -242.0909, -239.3407, -237.2902, -…
## $ b_x3        <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA…

We still have the various data frames stacked atop one another, with the data from post7 in the first 4,000 rows. See how the values in the b_x3 column are all missing (i.e., filled with NA values)? That's because fit7 didn't contain x3 as a predictor. Similarly, if we were to look at rows 4,001 through 8,000, we'd see that column b_x2 is the one filled with NAs. This behavior is a good thing here. After a little more wrangling, we'll plot, and it should become clear why. Here's the wrangling.

posts <-
  posts %>% 
  select(starts_with("b_x")) %>% 
  mutate(contains = rep(c("<1, 1, 0>", "<1, 0, 1>", "<0, 1, 1>", "<1, 1, 1>"), each = 4000)) %>% 
  gather(key, value, -contains) %>% 
  mutate(coefficient = str_remove(key, "b_x") %>% str_c("beta[", ., "]"))

head(posts)
##    contains  key       value coefficient
## 1 <1, 1, 0> b_x1 -0.01888771     beta[1]
## 2 <1, 1, 0> b_x1 -0.15602461     beta[1]
## 3 <1, 1, 0> b_x1 -0.24841475     beta[1]
## 4 <1, 1, 0> b_x1  0.05744279     beta[1]
## 5 <1, 1, 0> b_x1  0.24187423     beta[1]
## 6 <1, 1, 0> b_x1 -0.35049990     beta[1]

With the contains variable, we indexed which fit the draws came from. The 1's and 0's within the angle brackets indicate which of the three predictors were present in the model: a 1 means the predictor was included and a 0 means it was not. For example, the <1, 1, 0> in the first row indicates this was the model including x1 and x2. Importantly, we also added a coefficient index. This is just a variant of key that'll make the strip labels in our plot more attractive: because label_parsed parses strings like beta[1] as plotmath, each strip displays a subscripted \(\beta\). Behold:

posts %>% 
  drop_na(value) %>% 
  ggplot(aes(x = value, y = contains)) +
  stat_halfeye() +
  ylab(NULL) +
  facet_wrap(~coefficient, ncol = 1, labeller = label_parsed)

Hopefully now it’s clear why it was good to save those cells with the NAs.
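The hand-typed contains labels and the each = 4000 in rep() both assume you know each model's predictors and draw count in advance. Here's a minimal sketch of one alternative, assuming dplyr 1.0 or later: a hypothetical helper, contains_label(), derives each label from the slope columns a posterior data frame actually has, and bind_rows() with .id turns the list names into the index column. The data frames are simulated stand-ins with deliberately unequal draw counts, not real brms output.

```r
library(dplyr)

# Hypothetical helper: builds a "<1, 1, 0>"-style label from the slope columns present in a data frame
contains_label <- function(post, slopes = c("b_x1", "b_x2", "b_x3")) {
  paste0("<", paste(as.integer(slopes %in% names(post)), collapse = ", "), ">")
}

set.seed(1)
# Simulated stand-ins for post7 (x1 + x2) and post10 (x1 + x2 + x3)
post7_fake  <- data.frame(b_x1 = rnorm(5), b_x2 = rnorm(5))
post10_fake <- data.frame(b_x1 = rnorm(3), b_x2 = rnorm(3), b_x3 = rnorm(3))

fits <- list(post7_fake, post10_fake)
names(fits) <- vapply(fits, contains_label, character(1))

# bind_rows() fills slopes a model lacks with NA and stores each list name in `contains`
posts <- bind_rows(fits, .id = "contains")
```

With the real objects, you'd start from list(post7, post8, post9, post10) instead, keep the contains and b_x columns, and continue with the gather() and drop_na() steps from above. The same .id idea could stand in for the rep(., each = 4000) line in the first example, e.g., bind_rows("normal(0, 10)" = post1, "normal(0, 1)" = post2, "normal(0, 0.1)" = post3, .id = "prior").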

Bonus: You can streamline your workflow.

The workflows above are generally fine. But they’re a little inefficient. If you’d like to reduce the amount of code you’re writing and the number of objects you have floating around in your environment, you might consider a more streamlined workflow where you work with your fit objects in bulk. Here we’ll demonstrate a nested tibble approach with the first three fits.

posts <-
  tibble(name  = str_c("fit", 1:3),
         prior = str_c("normal(0, ", c(10, 1, 0.1), ")")) %>% 
  mutate(fit = map(name, get)) %>% 
  mutate(post = map(fit, posterior_samples))
  
head(posts)
## # A tibble: 3 x 4
##   name  prior          fit       post                
##   <chr> <chr>          <list>    <list>              
## 1 fit1  normal(0, 10)  <brmsfit> <df[,3] [4,000 × 3]>
## 2 fit2  normal(0, 1)   <brmsfit> <df[,3] [4,000 × 3]>
## 3 fit3  normal(0, 0.1) <brmsfit> <df[,3] [4,000 × 3]>

We have a 3-row nested tibble. The first column, name, is just a character vector with the names of the fits. The next column isn't necessary, but it nicely explicates the main difference among the models: the prior we used on the intercept. It's in the map() functions within the two mutate() lines where all the magic happens. With the first, we used the get() function to snatch up the brms fit objects matching the names in the name column. With the second, we used the posterior_samples() function to extract the posterior draws from each of the fits saved in fit. Do you see how each row in the post column contains an entire \(4,000 \times 3\) data frame? That's why we refer to this as a nested tibble: we have data frames compressed within data frames. If you'd like to access the data within the post column, just unnest().

posts %>% 
  select(-fit) %>% 
  unnest(post)
## # A tibble: 12,000 x 5
##    name  prior         b_Intercept sigma  lp__
##    <chr> <chr>               <dbl> <dbl> <dbl>
##  1 fit1  normal(0, 10)     0.0644  0.941 -202.
##  2 fit1  normal(0, 10)     0.0260  0.942 -202.
##  3 fit1  normal(0, 10)    -0.0212  0.897 -202.
##  4 fit1  normal(0, 10)     0.0262  0.952 -202.
##  5 fit1  normal(0, 10)     0.0262  0.952 -202.
##  6 fit1  normal(0, 10)     0.0803  0.910 -202.
##  7 fit1  normal(0, 10)    -0.00142 0.886 -202.
##  8 fit1  normal(0, 10)     0.0696  0.939 -202.
##  9 fit1  normal(0, 10)    -0.172   0.943 -205.
## 10 fit1  normal(0, 10)     0.0259  0.839 -203.
## # … with 11,990 more rows

After un-nesting, we can remake the plot from above.

posts %>% 
  select(-fit) %>% 
  unnest(post) %>% 
  ggplot(aes(x = b_Intercept, y = prior)) +
  stat_halfeye()
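If you'd like to practice the nested-tibble pattern without fitting any models, here's a minimal sketch using tibble, purrr, and tidyr. The function fake_draws() is a hypothetical stand-in for posterior_samples(): it simulates intercept draws whose spread follows the known-\(\sigma\) normal-normal approximation with \(n = 150\) and \(\sigma = 1\), so it only mimics the shape of the real output.

```r
library(tibble)
library(purrr)
library(tidyr)

set.seed(1)
# Hypothetical stand-in for posterior_samples(): simulated intercept draws, not brms output
fake_draws <- function(prior_sd, n_draws = 4000) {
  data.frame(b_Intercept = rnorm(n_draws, mean = 0,
                                 sd = 1 / sqrt(1 / prior_sd^2 + 150)))
}

# tibble() builds columns in order, so post can refer to prior_sd;
# map() returns a list, which becomes a list-column of data frames
posts <- tibble(
  prior    = c("normal(0, 10)", "normal(0, 1)", "normal(0, 0.1)"),
  prior_sd = c(10, 1, 0.1),
  post     = map(prior_sd, fake_draws)
)

# Un-nesting expands each row's data frame into 4,000 rows
draws_long <- unnest(posts, post)
```

The draws_long object has the same layout the stat_halfeye() plot above expects, with prior on the y axis and b_Intercept on the x axis.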

To learn more about using the tidyverse for iterating and saving the results in nested tibbles, check out Hadley Wickham’s great talk, Managing many models.

Session information

sessionInfo()
## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.7
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] tidybayes_2.3.1 brms_2.15.0     Rcpp_1.0.6      forcats_0.5.1   stringr_1.4.0   dplyr_1.0.5    
##  [7] purrr_0.3.4     readr_1.4.0     tidyr_1.1.3     tibble_3.1.0    ggplot2_3.3.3   tidyverse_1.3.0
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.3.1         backports_1.2.1      plyr_1.8.6           igraph_1.2.6        
##   [5] splines_4.0.4        svUnit_1.0.3         crosstalk_1.1.0.1    TH.data_1.0-10      
##   [9] rstantools_2.1.1     inline_0.3.17        digest_0.6.27        htmltools_0.5.1.1   
##  [13] rsconnect_0.8.16     fansi_0.4.2          magrittr_2.0.1       modelr_0.1.8        
##  [17] RcppParallel_5.0.2   matrixStats_0.57.0   xts_0.12.1           sandwich_3.0-0      
##  [21] prettyunits_1.1.1    colorspace_2.0-0     rvest_0.3.6          ggdist_2.4.0.9000   
##  [25] haven_2.3.1          xfun_0.22            callr_3.5.1          crayon_1.4.1        
##  [29] jsonlite_1.7.2       lme4_1.1-25          survival_3.2-10      zoo_1.8-8           
##  [33] glue_1.4.2           gtable_0.3.0         emmeans_1.5.2-1      V8_3.4.0            
##  [37] distributional_0.2.2 pkgbuild_1.2.0       rstan_2.21.2         abind_1.4-5         
##  [41] scales_1.1.1         mvtnorm_1.1-1        DBI_1.1.0            miniUI_0.1.1.1      
##  [45] xtable_1.8-4         stats4_4.0.4         StanHeaders_2.21.0-7 DT_0.16             
##  [49] htmlwidgets_1.5.2    httr_1.4.2           threejs_0.3.3        arrayhelpers_1.1-0  
##  [53] ellipsis_0.3.1       farver_2.0.3         pkgconfig_2.0.3      loo_2.4.1           
##  [57] dbplyr_2.0.0         utf8_1.1.4           labeling_0.4.2       tidyselect_1.1.0    
##  [61] rlang_0.4.10         reshape2_1.4.4       later_1.1.0.1        munsell_0.5.0       
##  [65] cellranger_1.1.0     tools_4.0.4          cli_2.3.1            generics_0.1.0      
##  [69] broom_0.7.5          ggridges_0.5.2       evaluate_0.14        fastmap_1.0.1       
##  [73] yaml_2.2.1           processx_3.4.5       knitr_1.31           fs_1.5.0            
##  [77] nlme_3.1-152         mime_0.10            projpred_2.0.2       xml2_1.3.2          
##  [81] compiler_4.0.4       bayesplot_1.8.0      shinythemes_1.1.2    rstudioapi_0.13     
##  [85] gamm4_0.2-6          curl_4.3             reprex_0.3.0         statmod_1.4.35      
##  [89] stringi_1.5.3        highr_0.8            ps_1.6.0             blogdown_1.3        
##  [93] Brobdingnag_1.2-6    lattice_0.20-41      Matrix_1.3-2         nloptr_1.2.2.2      
##  [97] markdown_1.1         shinyjs_2.0.0        vctrs_0.3.6          pillar_1.5.1        
## [101] lifecycle_1.0.0      bridgesampling_1.0-0 estimability_1.3     httpuv_1.5.4        
## [105] R6_2.5.0             bookdown_0.21        promises_1.1.1       gridExtra_2.3       
## [109] codetools_0.2-18     boot_1.3-26          colourpicker_1.1.0   MASS_7.3-53         
## [113] gtools_3.8.2         assertthat_0.2.1     withr_2.4.1          shinystan_2.5.0     
## [117] multcomp_1.4-16      mgcv_1.8-33          parallel_4.0.4       hms_0.5.3           
## [121] grid_4.0.4           coda_0.19-4          minqa_1.2.4          rmarkdown_2.7       
## [125] shiny_1.5.0          lubridate_1.7.9.2    base64enc_0.1-3      dygraphs_1.1.1.6
