Skip to content

Commit bf397e4

Browse files
committed
Add handling for variable order, fix chains return in draws
1 parent d34b77e commit bf397e4

3 files changed

Lines changed: 15 additions & 10 deletions

File tree

R/fit.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ unconstrain_variables <- function(variables) {
518518
" not provided!", call. = FALSE)
519519
}
520520

521-
variables_vector <- unlist(variables, recursive = TRUE, use.names = FALSE)
521+
variables_vector <- unlist(variables[model_par_names], recursive = TRUE)
522522
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
523523
}
524524
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)
@@ -598,7 +598,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
598598
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
599599
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
600600
names(unconstrained) <- repair_variable_names(uncon_names)
601-
maybe_convert_draws_format(unconstrained, format)
601+
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
602602
}
603603
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
604604

R/model.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,10 +465,10 @@ compile <- function(quiet = TRUE,
465465
stanc_options = list(),
466466
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
467467
compile_model_methods = FALSE,
468-
compile_hessian_method = FALSE,
469468
compile_standalone = FALSE,
470469
dry_run = FALSE,
471470
#deprecated
471+
compile_hessian_method = FALSE,
472472
threads = FALSE) {
473473

474474
if (length(self$stan_file()) == 0) {
@@ -505,6 +505,11 @@ compile <- function(quiet = TRUE,
505505
cpp_options[["stan_threads"]] <- TRUE
506506
}
507507

508+
# temporary deprecation warnings
509+
if (isTRUE(compile_hessian_method)) {
510+
warning("'compile_hessian_method' is deprecated. The hessian method is compiled with all models.")
511+
}
512+
508513
if (length(self$exe_file()) == 0) {
509514
if (is.null(dir)) {
510515
exe_base <- self$stan_file()

R/utils.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,19 @@ valid_draws_formats <- function() {
409409
"draws_rvars", "rvars")
410410
}
411411

412-
maybe_convert_draws_format <- function(draws, format) {
412+
maybe_convert_draws_format <- function(draws, format, ...) {
413413
if (is.null(draws)) {
414414
return(draws)
415415
}
416416
format <- sub("^draws_", "", format)
417417
switch(
418418
format,
419-
"array" = posterior::as_draws_array(draws),
420-
"df" = posterior::as_draws_df(draws),
421-
"data.frame" = posterior::as_draws_df(draws),
422-
"list" = posterior::as_draws_list(draws),
423-
"matrix" = posterior::as_draws_matrix(draws),
424-
"rvars" = posterior::as_draws_rvars(draws),
419+
"array" = posterior::as_draws_array(draws, ...),
420+
"df" = posterior::as_draws_df(draws, ...),
421+
"data.frame" = posterior::as_draws_df(draws, ...),
422+
"list" = posterior::as_draws_list(draws, ...),
423+
"matrix" = posterior::as_draws_matrix(draws, ...),
424+
"rvars" = posterior::as_draws_rvars(draws, ...),
425425
stop("Invalid draws format.", call. = FALSE)
426426
)
427427
}

0 commit comments

Comments
 (0)