The introductory vignette caters to Bayesian data analysis workflows with few datasets to analyze. However, it is sometimes desirable to run one or more Bayesian models repeatedly across many simulated datasets. Examples:
- Validate the implementation of a Bayesian model, using simulation to determine how reliably the model estimates the parameters under known data-generating scenarios.
- Simulate a randomized controlled experiment to explore frequentist properties such as power and Type I error.
This vignette focuses on the first of these. The goal of this particular example is to simulate multiple datasets from the model below, analyze each dataset, and assess how often the estimated posterior intervals cover the true parameters from the prior predictive simulations. The quantile method by Cook, Gelman, and Rubin (2006) generalizes this concept, and simulation-based calibration (SBC; Talts et al. 2020) generalizes it further. The interval-based technique featured in this vignette is not as robust as SBC, but it may be more expedient for large models because it does not require visual inspection of multiple histograms.
Consider a simple regression model with a continuous response \(y\) and a covariate \(x\).
\[ \begin{aligned} y_i &\stackrel{\text{iid}}{\sim} \text{Normal}(\beta_1 + x_i \beta_2, 1) \\ \beta_1, \beta_2 &\stackrel{\text{iid}}{\sim} \text{Normal}(0, 1) \end{aligned} \]
We write this model in a JAGS model file.
lines <- "model {
for (i in 1:n) {
y[i] ~ dnorm(beta[1] + x[i] * beta[2], 1)
}
for (i in 1:2) {
beta[i] ~ dnorm(0, 1)
}
}"
writeLines(lines, "model.jags")
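One detail connects the JAGS file to the model definition. JAGS parameterizes dnorm() by the mean and the precision \(\tau = 1/\sigma^2\), not by the standard deviation. The 1 passed as the second argument of each dnorm() call above therefore means unit variance, which also means unit standard deviation, as in the model:

\[ \text{dnorm}(\mu, \tau) \;\Longleftrightarrow\; \text{Normal}\left(\mu, \sigma^2 = 1/\tau\right), \qquad \tau = 1 \;\Rightarrow\; \sigma^2 = \sigma = 1 \]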
Next, we define a pipeline to simulate multiple datasets and fit each dataset with the model. In our data-generating function, we put the true parameter values of each simulation in a special .join_data list. jagstargets will automatically join the elements of .join_data to the correspondingly named variables in the summary output. This makes it easy to check how often our posterior intervals capture the truth. As for scale, we generate 20 datasets (5 batches with 4 replications each) and run the model on each of the 20 datasets.¹ By default, each of the 20 model runs computes 3 MCMC chains with 2000 MCMC iterations each (including burn-in). You can adjust these settings with the n.chains and n.iter arguments of tar_jags_rep_summary().
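The sketch below shows what "joining .join_data to the correspondingly named variables" means. It is a minimal illustration that assumes the summary has one row per scalar variable, named like beta[1] and beta[2]. The helper join_truth() is hypothetical. It is not the jagstargets internal implementation, and it handles only scalar and vector elements, not matrices.

```r
# Hypothetical helper for illustration only; this is NOT how jagstargets
# implements .join_data internally. It handles scalar and vector elements,
# not matrices or higher-dimensional arrays.
join_truth <- function(summary_df, join_data) {
  # Expand list(beta = c(b1, b2)) into c("beta[1]" = b1, "beta[2]" = b2).
  truth <- unlist(lapply(names(join_data), function(name) {
    value <- join_data[[name]]
    labels <- if (length(value) == 1L) {
      name
    } else {
      paste0(name, "[", seq_along(value), "]")
    }
    stats::setNames(value, labels)
  }))
  # Look up each row's true value by its variable name (NA if absent).
  summary_df$.join_data <- unname(truth[summary_df$variable])
  summary_df
}

# Toy rows copied from the first replicate in the model output shown later.
summary_rows <- data.frame(
  variable = c("beta[1]", "beta[2]"),
  q2.5 = c(-2.02, -0.373),
  q97.5 = c(-0.837, 1.37)
)
out <- join_truth(summary_rows, list(beta = c(-1.24, 1.02)))
out
```

Matching on the variable name means the data-generating function only has to use the same names as the JAGS parameters. It does not have to know the row order of the summary.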
# _targets.R
library(targets)
library(jagstargets)
options(crayon.enabled = FALSE)
# Use computer memory more sparingly:
tar_option_set(memory = "transient", garbage_collection = TRUE)
generate_data <- function(n = 10L) {
beta <- stats::rnorm(n = 2, mean = 0, sd = 1)
x <- seq(from = -1, to = 1, length.out = n)
y <- stats::rnorm(n, beta[1] + x * beta[2], 1)
# Elements of .join_data get joined on to the .join_data column
# in the summary output next to the model parameters
# with the same names.
.join_data <- list(beta = beta)
list(n = n, x = x, y = y, .join_data = .join_data)
}
list(
tar_jags_rep_summary(
model,
"model.jags",
data = generate_data(),
parameters.to.save = "beta",
batches = 5, # Number of branch targets.
reps = 4, # Number of model reps per branch target.
variables = "beta",
summaries = list(
~posterior::quantile2(.x, probs = c(0.025, 0.975))
)
)
)
We now have a pipeline that runs the model 20 times: 5 batches (branch targets) with 4 replications per batch.
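As a quick check of the scale, the arithmetic below uses the batches and reps from _targets.R and the default MCMC settings stated above. It assumes the summary holds one row per parameter per replicate, which matches the 40-row tibble printed later in this vignette.

```r
# Arithmetic behind the pipeline's scale, using the values in _targets.R
# and the default MCMC settings described in the text.
batches <- 5L     # branch targets
reps <- 4L        # replications per branch target
n_params <- 2L    # beta[1] and beta[2]
n_chains <- 3L    # default n.chains
n_iter <- 2000L   # default n.iter, including burn-in

n_fits <- batches * reps                      # simulated datasets = model runs
n_rows <- n_fits * n_params                   # summary rows: one per parameter per run
n_draws_total <- n_fits * n_chains * n_iter   # MCMC iterations across all runs

c(fits = n_fits, rows = n_rows, iterations = n_draws_total)
```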
tar_visnetwork()
Run the computation with tar_make().
tar_make()
#> ▶ dispatched target model_batch
#> ● completed target model_batch [0 seconds, 99 bytes]
#> ▶ dispatched target model_file_model
#> ● completed target model_file_model [0 seconds, 128 bytes]
#> ▶ dispatched branch model_data_5fcdec5f855f2d9c
#> ● completed branch model_data_5fcdec5f855f2d9c [0.006 seconds, 726 bytes]
#> ▶ dispatched branch model_data_b6c9a18333c6a8ca
#> ● completed branch model_data_b6c9a18333c6a8ca [0.001 seconds, 725 bytes]
#> ▶ dispatched branch model_data_5db4354944466148
#> ● completed branch model_data_5db4354944466148 [0.001 seconds, 729 bytes]
#> ▶ dispatched branch model_data_4a40cb783277d5dc
#> ● completed branch model_data_4a40cb783277d5dc [0.001 seconds, 727 bytes]
#> ▶ dispatched branch model_data_104af6d505e730d6
#> ● completed branch model_data_104af6d505e730d6 [0.001 seconds, 729 bytes]
#> ● completed pattern model_data
#> ▶ dispatched target model_lines_model
#> ● completed target model_lines_model [0 seconds, 144 bytes]
#> ▶ dispatched branch model_model_50b3d9bcb9189fef
#> ● completed branch model_model_50b3d9bcb9189fef [0.202 seconds, 1.629 kilobytes]
#> ▶ dispatched branch model_model_93bc2c2a4b8dc29f
#> ● completed branch model_model_93bc2c2a4b8dc29f [0.169 seconds, 1.629 kilobytes]
#> ▶ dispatched branch model_model_e2ab729f4fa1dd45
#> ● completed branch model_model_e2ab729f4fa1dd45 [0.17 seconds, 1.629 kilobytes]
#> ▶ dispatched branch model_model_5871bb9227fbbf93
#> ● completed branch model_model_5871bb9227fbbf93 [0.209 seconds, 1.629 kilobytes]
#> ▶ dispatched branch model_model_820c742ab2ba1134
#> ● completed branch model_model_820c742ab2ba1134 [0.169 seconds, 1.629 kilobytes]
#> ● completed pattern model_model
#> ▶ dispatched target model
#> ● completed target model [0 seconds, 4.508 kilobytes]
#> ▶ ended pipeline [2.52 seconds]
The result is an aggregated data frame of summary statistics, where the .rep column distinguishes among individual replicates. We have the posterior intervals for beta in columns q2.5 and q97.5. And thanks to the .join_data list we included in generate_data(), our output has a .join_data column with the true values of the parameters in our simulations.
tar_load(model)
model
#> # A tibble: 40 × 9
#> variable q2.5 q97.5 .join_data .dataset_id .rep .seed .file .name
#> <chr> <dbl> <dbl> <dbl> <chr> <chr> <int> <chr> <chr>
#> 1 beta[1] -2.02 -0.837 -1.24 model_data_5fcd… fc26… -5.71e8 mode… model
#> 2 beta[2] -0.373 1.37 1.02 model_data_5fcd… fc26… -5.71e8 mode… model
#> 3 beta[1] 1.52 2.68 2.18 model_data_5fcd… b762… 1.03e9 mode… model
#> 4 beta[2] -0.824 0.913 0.315 model_data_5fcd… b762… 1.03e9 mode… model
#> 5 beta[1] 0.543 1.75 0.892 model_data_5fcd… 79b6… 1.92e9 mode… model
#> 6 beta[2] 1.40 3.19 1.98 model_data_5fcd… 79b6… 1.92e9 mode… model
#> 7 beta[1] 1.37 2.52 1.33 model_data_5fcd… e7a3… 1.95e9 mode… model
#> 8 beta[2] -0.914 0.762 -0.454 model_data_5fcd… e7a3… 1.95e9 mode… model
#> 9 beta[1] -0.0754 1.12 0.0642 model_data_b6c9… 969f… 7.78e8 mode… model
#> 10 beta[2] 0.0363 1.74 1.12 model_data_b6c9… 969f… 7.78e8 mode… model
#> # ℹ 30 more rows
Now, let’s assess how often the estimated 95% posterior intervals capture the true values of beta. If the model is implemented correctly, the coverage values below should be close to 95%. (Ordinarily, we would increase the number of batches and reps per batch and run the batches in parallel.)
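The advice to increase the number of batches and reps can be made concrete. Each coverage estimate is a proportion over independent simulated datasets, so its Monte Carlo standard error is approximately \(\sqrt{p(1 - p)/N}\) for true coverage \(p\) and \(N\) datasets. The sketch below assumes independent replicates and \(p = 0.95\). The helper coverage_mcse() is hypothetical and not part of jagstargets.

```r
# Hypothetical helper: binomial Monte Carlo standard error of a coverage
# proportion estimated from n_sims independent simulated datasets.
coverage_mcse <- function(p, n_sims) {
  sqrt(p * (1 - p) / n_sims)
}

n_sims <- c(20, 100, 1000)           # 20 matches this vignette's pipeline
mcse <- coverage_mcse(0.95, n_sims)
round(mcse, 3)
```

With 20 datasets the standard error is about 0.049, so a coverage estimate from this small example can only flag gross miscalibration. About 1000 datasets bring the standard error down to roughly 0.007.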
library(dplyr)
model %>%
group_by(variable) %>%
dplyr::summarize(coverage = mean(q2.5 < .join_data & .join_data < q97.5))
#> # A tibble: 2 × 2
#> variable coverage
#> <chr> <dbl>
#> 1 beta[1] 0.95
#> 2 beta[2] 0.95
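The same per-variable coverage can also be computed without dplyr. The base R sketch below uses a hypothetical helper, coverage_base(), and a toy input made of rows 1, 2, 7, and 8 of the model output printed above. It applies the same strict-inequality check as the dplyr code.

```r
# Hypothetical helper: the share of replicates whose interval (q2.5, q97.5)
# strictly contains the true value, computed separately for each variable.
coverage_base <- function(summary_df) {
  covered <- summary_df$q2.5 < summary_df$.join_data &
    summary_df$.join_data < summary_df$q97.5
  vapply(split(covered, summary_df$variable), mean, numeric(1))
}

# Toy input: rows 1, 2, 7, and 8 of the model output printed above.
toy <- data.frame(
  variable = c("beta[1]", "beta[2]", "beta[1]", "beta[2]"),
  q2.5 = c(-2.02, -0.373, 1.37, -0.914),
  q97.5 = c(-0.837, 1.37, 2.52, 0.762),
  .join_data = c(-1.24, 1.02, 1.33, -0.454)
)
res <- coverage_base(toy)
res
```

In the third toy row, the true beta[1] of 1.33 falls just below the lower bound of 1.37, which is a miss. beta[1] therefore scores 0.5 on these four rows, while beta[2] scores 1.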
For maximum reproducibility, we should express the coverage assessment as a custom function and a target in the pipeline.
# _targets.R
# packages needed to define the pipeline:
library(targets)
library(jagstargets)
tar_option_set(
packages = "dplyr", # packages needed to run the pipeline
memory = "transient", # memory efficiency
garbage_collection = TRUE # memory efficiency
)
generate_data <- function(n = 10L) {
beta <- stats::rnorm(n = 2, mean = 0, sd = 1)
x <- seq(from = -1, to = 1, length.out = n)
y <- stats::rnorm(n, beta[1] + x * beta[2], 1)
# Elements of .join_data get joined on to the .join_data column
# in the summary output next to the model parameters
# with the same names.
.join_data <- list(beta = beta)
list(n = n, x = x, y = y, .join_data = .join_data)
}
list(
tar_jags_rep_summary(
model,
"model.jags",
data = generate_data(),
parameters.to.save = "beta",
batches = 5, # Number of branch targets.
reps = 4, # Number of model reps per branch target.
variables = "beta",
summaries = list(
~posterior::quantile2(.x, probs = c(0.025, 0.975))
)
),
tar_target(
coverage,
model %>%
group_by(variable) %>%
summarize(
coverage = mean(q2.5 < .join_data & .join_data < q97.5),
.groups = "drop"
)
)
)
The new coverage target should be the only outdated target, and it should be connected to the upstream model target.
tar_visnetwork()
When we run the pipeline, only the coverage assessment should run. That way, we skip all the expensive computation of simulating datasets and running MCMC multiple times.
tar_make()
#> ✔ skipped target model_batch
#> ✔ skipped target model_file_model
#> ✔ skipped branch model_data_5fcdec5f855f2d9c
#> ✔ skipped branch model_data_b6c9a18333c6a8ca
#> ✔ skipped branch model_data_5db4354944466148
#> ✔ skipped branch model_data_4a40cb783277d5dc
#> ✔ skipped branch model_data_104af6d505e730d6
#> ✔ skipped pattern model_data
#> ✔ skipped target model_lines_model
#> ✔ skipped branch model_model_50b3d9bcb9189fef
#> ✔ skipped branch model_model_93bc2c2a4b8dc29f
#> ✔ skipped branch model_model_e2ab729f4fa1dd45
#> ✔ skipped branch model_model_5871bb9227fbbf93
#> ✔ skipped branch model_model_820c742ab2ba1134
#> ✔ skipped pattern model_model
#> ✔ skipped target model
#> ▶ dispatched target coverage
#> ● completed target coverage [0.015 seconds, 173 bytes]
#> ▶ ended pipeline [0.526 seconds]
tar_read(coverage)
#> # A tibble: 2 × 2
#> variable coverage
#> <chr> <dbl>
#> 1 beta[1] 0.95
#> 2 beta[2] 0.95
Multiple models
tar_jags_rep_summary() and similar functions allow you to supply multiple JAGS models. If you do, each model will share the same collection of datasets, and the .dataset_id column of the model target output allows for custom analyses that compare different models against each other. Below, we add a new model2.jags file to the jags_files argument of tar_jags_rep_summary(). In the coverage summary below, we group by .name to compute a coverage statistic for each model.
lines <- "model {
for (i in 1:n) {
y[i] ~ dnorm(beta[1] + x[i] * x[i] * beta[2], 1) # Regress on x^2, not x.
}
for (i in 1:2) {
beta[i] ~ dnorm(0, 1)
}
}"
writeLines(lines, "model2.jags")
# _targets.R
# packages needed to define the pipeline:
library(targets)
library(jagstargets)
tar_option_set(
packages = "dplyr", # packages needed to run the pipeline
memory = "transient", # memory efficiency
garbage_collection = TRUE # memory efficiency
)
generate_data <- function(n = 10L) {
beta <- stats::rnorm(n = 2, mean = 0, sd = 1)
x <- seq(from = -1, to = 1, length.out = n)
y <- stats::rnorm(n, beta[1] + x * beta[2], 1)
# Elements of .join_data get joined on to the .join_data column
# in the summary output next to the model parameters
# with the same names.
.join_data <- list(beta = beta)
list(n = n, x = x, y = y, .join_data = .join_data)
}
list(
tar_jags_rep_summary(
model,
c("model.jags", "model2.jags"), # another model
data = generate_data(),
parameters.to.save = "beta",
batches = 5,
reps = 4,
variables = "beta",
summaries = list(
~posterior::quantile2(.x, probs = c(0.025, 0.975))
)
),
tar_target(
coverage,
model %>%
group_by(.name) %>%
summarize(coverage = mean(q2.5 < .join_data & .join_data < q97.5))
)
)
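Because both models analyze the same simulated datasets, their summary rows can be paired for custom comparisons. The sketch below compares 95% interval widths between the two models for each dataset. It is a minimal base R illustration that rests on two assumptions. First, the .name values are "model" and "model2". Second, .dataset_id and .rep identify a simulated dataset identically for both models. The helper width_difference() and all toy numbers are made up for illustration and are not output from this pipeline.

```r
# Hypothetical helper: pair the two models' intervals for the same simulated
# dataset and report how much wider the second model's interval is.
# ASSUMPTION: .dataset_id and .rep identify a dataset identically across models.
width_difference <- function(summary_df, first = "model", second = "model2") {
  summary_df$width <- summary_df$q97.5 - summary_df$q2.5
  keys <- c(".dataset_id", ".rep", "variable")
  a <- summary_df[summary_df$.name == first, c(keys, "width")]
  b <- summary_df[summary_df$.name == second, c(keys, "width")]
  paired <- merge(a, b, by = keys, suffixes = c(".first", ".second"))
  paired$width_diff <- paired$width.second - paired$width.first
  paired
}

# Made-up toy summary with two replicates analyzed by both models.
toy <- data.frame(
  .name = c("model", "model2", "model", "model2"),
  .dataset_id = "d1",
  .rep = c("r1", "r1", "r2", "r2"),
  variable = "beta[1]",
  q2.5 = c(-1.0, -1.5, 0.0, -0.2),
  q97.5 = c(1.0, 1.5, 1.0, 1.4)
)
out <- width_difference(toy)
out[, c(".rep", "width_diff")]
```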
In the graph below, notice how targets model_model and model_model2 are both connected to model_data upstream. Downstream, model is equivalent to dplyr::bind_rows(model_model, model_model2), and it will have special columns .name and .file to distinguish among all the models.
tar_visnetwork()