Skip to content

Commit a79cc5e

Browse files
authored
Automatically init model methods, add inc_warmup argument to unconstrain_draws() method (#985)
* Automatically init model methods when not compiled * Automatically init model methods, add inc_warmup argument to unconstrain_draws
1 parent 1679aa7 commit a79cc5e

10 files changed

Lines changed: 64 additions & 105 deletions

R/fit.R

Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,12 @@ CmdStanFit$set("public", name = "init", value = init)
326326
#' @examples
327327
#' \dontrun{
328328
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
329-
#' fit_mcmc$init_model_methods()
330329
#' }
331330
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
332331
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
333332
#' [hessian()]
334333
#'
335-
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
334+
init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE) {
336335
if (os_is_wsl()) {
337336
stop("Additional model methods are not currently available with ",
338337
"WSL CmdStan and will not be compiled",
@@ -348,11 +347,12 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
348347
"which is still experimental. Please report any compilation ",
349348
"errors that you encounter")
350349
}
351-
message("Compiling additional model methods...")
352350
if (is.null(private$model_methods_env_$model_ptr)) {
353351
expose_model_methods(private$model_methods_env_, verbose, hessian)
354352
}
355-
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
353+
if (!("model_ptr_" %in% ls(private$model_methods_env_))) {
354+
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
355+
}
356356
invisible(NULL)
357357
}
358358
CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods)
@@ -372,7 +372,6 @@ CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods
372372
#' @examples
373373
#' \dontrun{
374374
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
375-
#' fit_mcmc$init_model_methods()
376375
#' fit_mcmc$log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
377376
#' }
378377
#'
@@ -385,10 +384,7 @@ log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustme
385384
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
386385
jacobian <- jacobian_adjustment
387386
}
388-
if (is.null(private$model_methods_env_$model_ptr)) {
389-
stop("The method has not been compiled, please call `init_model_methods()` first",
390-
call. = FALSE)
391-
}
387+
self$init_model_methods()
392388
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
393389
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
394390
length(unconstrained_variables), " were provided!", call. = FALSE)
@@ -410,7 +406,6 @@ CmdStanFit$set("public", name = "log_prob", value = log_prob)
410406
#' @examples
411407
#' \dontrun{
412408
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
413-
#' fit_mcmc$init_model_methods()
414409
#' fit_mcmc$grad_log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
415410
#' }
416411
#'
@@ -423,10 +418,7 @@ grad_log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adj
423418
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
424419
jacobian <- jacobian_adjustment
425420
}
426-
if (is.null(private$model_methods_env_$model_ptr)) {
427-
stop("The method has not been compiled, please call `init_model_methods()` first",
428-
call. = FALSE)
429-
}
421+
self$init_model_methods()
430422
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
431423
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
432424
length(unconstrained_variables), " were provided!", call. = FALSE)
@@ -461,10 +453,7 @@ hessian <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustmen
461453
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
462454
jacobian <- jacobian_adjustment
463455
}
464-
if (is.null(private$model_methods_env_$model_ptr)) {
465-
stop("The method has not been compiled, please call `init_model_methods()` first",
466-
call. = FALSE)
467-
}
456+
self$init_model_methods()
468457
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
469458
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
470459
length(unconstrained_variables), " were provided!", call. = FALSE)
@@ -487,7 +476,6 @@ CmdStanFit$set("public", name = "hessian", value = hessian)
487476
#' @examples
488477
#' \dontrun{
489478
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
490-
#' fit_mcmc$init_model_methods()
491479
#' fit_mcmc$unconstrain_variables(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
492480
#' }
493481
#'
@@ -496,10 +484,7 @@ CmdStanFit$set("public", name = "hessian", value = hessian)
496484
#' [hessian()]
497485
#'
498486
unconstrain_variables <- function(variables) {
499-
if (is.null(private$model_methods_env_$model_ptr)) {
500-
stop("The method has not been compiled, please call `init_model_methods()` first",
501-
call. = FALSE)
502-
}
487+
self$init_model_methods()
503488
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
504489
prov_par_names <- names(variables)
505490

@@ -539,11 +524,12 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
539524
#' @param draws A `posterior::draws_*` object.
540525
#' @param format (string) The format of the returned draws. Must be a valid
541526
#' format from the \pkg{posterior} package.
527+
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to
528+
#' `FALSE`.
542529
#'
543530
#' @examples
544531
#' \dontrun{
545532
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
546-
#' fit_mcmc$init_model_methods()
547533
#'
548534
#' # Unconstrain all internal draws
549535
#' unconstrained_internal_draws <- fit_mcmc$unconstrain_draws()
@@ -560,7 +546,9 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
560546
#' [hessian()]
561547
#'
562548
unconstrain_draws <- function(files = NULL, draws = NULL,
563-
format = getOption("cmdstanr_draws_format", "draws_array")) {
549+
format = getOption("cmdstanr_draws_format", "draws_array"),
550+
inc_warmup = FALSE) {
551+
self$init_model_methods()
564552
if (!(format %in% valid_draws_formats())) {
565553
stop("Invalid draws format requested!", call. = FALSE)
566554
}
@@ -570,22 +558,25 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
570558
call. = FALSE)
571559
}
572560
if (!is.null(files)) {
573-
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
574-
draws <- read_csv$post_warmup_draws
575-
}
576-
if (!is.null(draws)) {
577-
draws <- maybe_convert_draws_format(draws, "draws_matrix")
578-
}
579-
} else {
580-
if (is.null(private$draws_)) {
581-
if (!length(self$output_files(include_failed = FALSE))) {
582-
stop("Fitting failed. Unable to retrieve the draws.", call. = FALSE)
561+
read_csv <- read_cmdstan_csv(files = files)
562+
if (inc_warmup) {
563+
draws <- posterior::bind_draws(read_csv$warmup_draws,
564+
read_csv$post_warmup_draws,
565+
along = "iteration")
566+
} else {
567+
draws <- read_csv$post_warmup_draws
568+
}
569+
} else if (!is.null(draws)) {
570+
if (inc_warmup) {
571+
message("'inc_warmup' cannot be used with a draws object. Ignoring.")
583572
}
584-
private$read_csv_(format = "draws_df")
585573
}
586-
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
574+
} else {
575+
draws <- self$draws(inc_warmup = inc_warmup)
587576
}
588577

578+
draws <- maybe_convert_draws_format(draws, "draws_matrix")
579+
589580
chains <- posterior::nchains(draws)
590581

591582
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
@@ -624,7 +615,6 @@ CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
624615
#' @examples
625616
#' \dontrun{
626617
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
627-
#' fit_mcmc$init_model_methods()
628618
#' fit_mcmc$variable_skeleton()
629619
#' }
630620
#'
@@ -633,11 +623,7 @@ CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
633623
#' [hessian()]
634624
#'
635625
variable_skeleton <- function(transformed_parameters = TRUE, generated_quantities = TRUE) {
636-
if (is.null(private$model_methods_env_$model_ptr)) {
637-
stop("The method has not been compiled, please call `init_model_methods()` first",
638-
call. = FALSE)
639-
}
640-
626+
self$init_model_methods()
641627
create_skeleton(private$model_methods_env_$param_metadata_,
642628
self$runset$args$model_variables,
643629
transformed_parameters,
@@ -662,7 +648,6 @@ CmdStanFit$set("public", name = "variable_skeleton", value = variable_skeleton)
662648
#' @examples
663649
#' \dontrun{
664650
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
665-
#' fit_mcmc$init_model_methods()
666651
#' fit_mcmc$constrain_variables(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
667652
#' }
668653
#'
@@ -671,12 +656,8 @@ CmdStanFit$set("public", name = "variable_skeleton", value = variable_skeleton)
671656
#' [hessian()]
672657
#'
673658
constrain_variables <- function(unconstrained_variables, transformed_parameters = TRUE,
674-
generated_quantities = TRUE) {
675-
if (is.null(private$model_methods_env_$model_ptr)) {
676-
stop("The method has not been compiled, please call `init_model_methods()` first",
677-
call. = FALSE)
678-
}
679-
659+
generated_quantities = TRUE) {
660+
self$init_model_methods()
680661
skeleton <- self$variable_skeleton(transformed_parameters, generated_quantities)
681662

682663
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {

R/utils.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,9 @@ rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
786786
}
787787

788788
expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
789+
if (rlang::is_interactive()) {
790+
message("Compiling additional model methods...")
791+
}
789792
code <- c(env$hpp_code_,
790793
readLines(system.file("include", "model_methods.cpp",
791794
package = "cmdstanr", mustWork = TRUE)))
@@ -1034,7 +1037,9 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE)
10341037
})
10351038
}
10361039
} else {
1037-
message("Compiling standalone functions...")
1040+
if (rlang::is_interactive()) {
1041+
message("Compiling standalone functions...")
1042+
}
10381043
compile_functions(function_env, verbose, global)
10391044
}
10401045
invisible(NULL)

man/fit-method-constrain_variables.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-grad_log_prob.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-init_model_methods.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-log_prob.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-unconstrain_draws.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-unconstrain_variables.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-variable_skeleton.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)