Skip to content

Commit 499aa23

Browse files
authored
Fix spurious cmdstan config errors (#981)
1 parent 91e4bf6 commit 499aa23

9 files changed

Lines changed: 35 additions & 77 deletions

R/model.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,8 @@ sample <- function(data = NULL,
11671167
show_messages = TRUE,
11681168
show_exceptions = TRUE,
11691169
diagnostics = c("divergences", "treedepth", "ebfmi"),
1170-
save_metric = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
1171-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
1170+
save_metric = NULL,
1171+
save_cmdstan_config = NULL,
11721172
# deprecated
11731173
cores = NULL,
11741174
num_cores = NULL,
@@ -1379,7 +1379,7 @@ sample_mpi <- function(data = NULL,
13791379
show_messages = TRUE,
13801380
show_exceptions = TRUE,
13811381
diagnostics = c("divergences", "treedepth", "ebfmi"),
1382-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
1382+
save_cmdstan_config = NULL,
13831383
# deprecated
13841384
validate_csv = TRUE) {
13851385

@@ -1525,7 +1525,7 @@ optimize <- function(data = NULL,
15251525
history_size = NULL,
15261526
show_messages = TRUE,
15271527
show_exceptions = TRUE,
1528-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
1528+
save_cmdstan_config = NULL) {
15291529
procs <- CmdStanProcs$new(
15301530
num_procs = 1,
15311531
show_stderr_messages = show_exceptions,
@@ -1659,7 +1659,7 @@ laplace <- function(data = NULL,
16591659
draws = NULL,
16601660
show_messages = TRUE,
16611661
show_exceptions = TRUE,
1662-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
1662+
save_cmdstan_config = NULL) {
16631663
if (cmdstan_version() < "2.32") {
16641664
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
16651665
}
@@ -1815,7 +1815,7 @@ variational <- function(data = NULL,
18151815
draws = NULL,
18161816
show_messages = TRUE,
18171817
show_exceptions = TRUE,
1818-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
1818+
save_cmdstan_config = NULL) {
18191819
procs <- CmdStanProcs$new(
18201820
num_procs = 1,
18211821
show_stderr_messages = show_exceptions,
@@ -1960,7 +1960,7 @@ pathfinder <- function(data = NULL,
19601960
calculate_lp = NULL,
19611961
show_messages = TRUE,
19621962
show_exceptions = TRUE,
1963-
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
1963+
save_cmdstan_config = NULL) {
19641964
procs <- CmdStanProcs$new(
19651965
num_procs = 1,
19661966
show_stderr_messages = show_exceptions,

R/run.R

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ CmdStanRun <- R6::R6Class(
2626
if (cmdstan_version() >= "2.26.0") {
2727
private$profile_files_ <- self$new_profile_files()
2828
}
29-
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$save_cmdstan_config) && self$args$save_cmdstan_config) {
29+
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$save_cmdstan_config) && as.logical(self$args$save_cmdstan_config)) {
3030
private$config_files_ <- self$new_config_files()
3131
}
32-
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$method_args$save_metric) && self$args$method_args$save_metric) {
32+
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$method_args$save_metric) && as.logical(self$args$method_args$save_metric)) {
3333
private$metric_files_ <- self$new_metric_files()
3434
}
3535
if (self$args$save_latent_dynamics) {
@@ -77,13 +77,6 @@ CmdStanRun <- R6::R6Class(
7777
config_files = function(include_failed = FALSE) {
7878
files <- private$config_files_
7979
files_win_path <- sapply(private$config_files_, wsl_safe_path, revert = TRUE)
80-
if (!length(files) || !any(file.exists(files_win_path))) {
81-
stop(
82-
"No CmdStan config files found. ",
83-
"Set 'save_cmdstan_config=TRUE' when fitting the model.",
84-
call. = FALSE
85-
)
86-
}
8780
if (include_failed) {
8881
files
8982
} else {
@@ -94,13 +87,6 @@ CmdStanRun <- R6::R6Class(
9487
metric_files = function(include_failed = FALSE) {
9588
files <- private$metric_files_
9689
files_win_path <- sapply(private$metric_files_, wsl_safe_path, revert = TRUE)
97-
if (!length(files) || !any(file.exists(files_win_path))) {
98-
stop(
99-
"No metric files found. ",
100-
"Set 'save_metric=TRUE' when fitting the model.",
101-
call. = FALSE
102-
)
103-
}
10490
if (include_failed) {
10591
files
10692
} else {
@@ -404,12 +390,12 @@ CmdStanRun <- R6::R6Class(
404390
private$profile_files_,
405391
if (cmdstan_version() > "2.34.0" &&
406392
!is.null(self$args$save_cmdstan_config) &&
407-
self$args$save_cmdstan_config &&
393+
as.logical(self$args$save_cmdstan_config) &&
408394
!private$config_files_saved_)
409395
self$config_files(include_failed = TRUE),
410396
if (cmdstan_version() > "2.34.0" &&
411397
!(is.null(self$args$method_args$save_metric)) &&
412-
self$args$method_args$save_metric &&
398+
as.logical(self$args$method_args$save_metric) &&
413399
!private$metric_files_saved_)
414400
self$metric_files(include_failed = TRUE)
415401
)

man/model-method-laplace.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-optimize.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-pathfinder.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-sample.Rd

Lines changed: 2 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-sample_mpi.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-variational.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model-output_dir.R

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,23 @@ test_that("all fitting methods work with output_dir", {
2727
files <- list.files(method_dir)
2828
}
2929
# specifying output_dir
30-
fit <- testing_fit("bernoulli", method = method, seed = 123,
31-
output_dir = method_dir)
30+
call_args <- list(
31+
"bernoulli",
32+
method = method,
33+
seed = 123,
34+
output_dir = method_dir,
35+
save_cmdstan_config = TRUE
36+
)
37+
if (method == "sample") {
38+
call_args$save_metric <- TRUE
39+
}
40+
fit <- do.call(testing_fit, call_args)
3241
# WSL path manipulations result in a short path which slightly differs
3342
# from the original tempdir(), so need to normalise both for comparison
3443
expect_equal(normalizePath(fit$runset$args$output_dir),
3544
normalizePath(method_dir))
3645
files <- normalizePath(list.files(method_dir, full.names = TRUE))
37-
# in 2.34.0 we also save the config files for all methods and the metric
38-
# for sample
39-
if (cmdstan_version() < "2.34.0") {
40-
mult <- 1
41-
} else if (method == "sample") {
46+
if (method == "sample") {
4247
mult <- 3
4348
expect_equal(files[grepl("metric", files)],
4449
normalizePath(sapply(fit$metric_files(), wsl_safe_path, revert = TRUE,
@@ -99,7 +104,10 @@ test_that("error if output_dir is invalid", {
99104
})
100105

101106
test_that("output_dir works with trailing /", {
102-
test_dir <- file.path(sandbox, "trailing")
107+
test_dir <- file.path(tempdir(check = TRUE), "output_dir")
108+
if (dir.exists(test_dir)) {
109+
unlink(test_dir, recursive = TRUE)
110+
}
103111
dir.create(test_dir)
104112
fit <- testing_fit(
105113
"bernoulli",
@@ -109,7 +117,5 @@ test_that("output_dir works with trailing /", {
109117
)
110118
expect_equal(normalizePath(fit$runset$args$output_dir),
111119
normalizePath(test_dir))
112-
# in 2.34.0 we also save the metric and config files
113-
mult <- if (cmdstan_version() >= "2.34.0") 3 else 1
114-
expect_equal(length(list.files(test_dir)), mult * fit$num_procs())
120+
expect_equal(length(list.files(test_dir)), fit$num_procs())
115121
})

0 commit comments

Comments
 (0)