Skip to content

Commit 2b04e4f

Browse files
authored
Merge pull request #811 from stan-dev/qol-improvements
Bugfixes in .stanfunctions, hessian model method, and exposing RNG functions
2 parents e4ff5d4 + 7fb88b6 commit 2b04e4f

4 files changed

Lines changed: 49 additions & 11 deletions

File tree

R/model.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ CmdStanModel <- R6::R6Class(
230230
self$functions <- new.env()
231231
self$functions$compiled <- FALSE
232232
if (!is.null(stan_file)) {
233-
assert_file_exists(stan_file, access = "r", extension = "stan")
233+
assert_file_exists(stan_file, access = "r", extension = c("stan", "stanfunctions"))
234234
checkmate::assert_flag(compile)
235235
private$stan_file_ <- absolute_path(stan_file)
236236
private$stan_code_ <- readLines(stan_file)
@@ -537,7 +537,7 @@ compile <- function(quiet = TRUE,
537537
compile_hessian_method <- FALSE
538538
}
539539

540-
temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan")
540+
temp_stan_file <- tempfile(pattern = "model-", fileext = paste0(".", tools::file_ext(self$stan_file())))
541541
file.copy(self$stan_file(), temp_stan_file, overwrite = TRUE)
542542
temp_file_no_ext <- strip_ext(temp_stan_file)
543543
tmp_exe <- cmdstan_ext(temp_file_no_ext) # adds .exe on Windows

R/utils.R

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
748748
package = "cmdstanr", mustWork = TRUE)))
749749

750750
if (hessian) {
751-
code <- c(code,
751+
code <- c("#include <stan/math/mix.hpp>",
752+
code,
752753
readLines(system.file("include", "hessian.cpp",
753754
package = "cmdstanr", mustWork = TRUE)))
754755
}
@@ -758,9 +759,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
758759
invisible(NULL)
759760
}
760761

761-
initialize_model_pointer <- function(env, data, seed = 0) {
762-
datafile_path <- ifelse(is.null(data), "", data)
763-
ptr_and_rng <- env$model_ptr(datafile_path, seed)
762+
initialize_model_pointer <- function(env, datafile_path, seed = 0) {
763+
ptr_and_rng <- env$model_ptr(ifelse(is.null(datafile_path), "", datafile_path), seed)
764764
env$model_ptr_ <- ptr_and_rng$model_ptr
765765
env$model_rng_ <- ptr_and_rng$base_rng
766766
env$num_upars_ <- env$get_num_upars(env$model_ptr_)
@@ -863,8 +863,8 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
863863
fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body)
864864
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
865865
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
866-
fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE)
867-
fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE)
866+
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
867+
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<boost::ecuyer1988>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
868868
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
869869
fun_body <- paste(fun_body, collapse = "\n")
870870
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
@@ -904,6 +904,30 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
904904
} else {
905905
rcpp_source_stan(mod_stan_funs, env, verbose)
906906
}
907+
908+
# If an RNG function is exposed, initialise a Boost RNG object stored in the
909+
# environment
910+
rng_funs <- grep("rng\\b", env$fun_names, value = TRUE)
911+
if (length(rng_funs) > 0) {
912+
rng_cpp <- system.file("include", "base_rng.cpp", package = "cmdstanr", mustWork = TRUE)
913+
rcpp_source_stan(paste0(readLines(rng_cpp), collapse="\n"), env, verbose)
914+
env$rng_ptr <- env$base_rng(seed=0)
915+
}
916+
917+
# For all RNG functions, pass the initialised Boost RNG by default
918+
for (fun in rng_funs) {
919+
if (global) {
920+
fun_env <- globalenv()
921+
} else {
922+
fun_env <- env
923+
}
924+
fundef <- get(fun, envir = fun_env)
925+
funargs <- formals(fundef)
926+
funargs$base_rng_ptr <- env$rng_ptr
927+
formals(fundef) <- funargs
928+
assign(fun, fundef, envir = fun_env)
929+
}
930+
907931
env$compiled <- TRUE
908932
invisible(NULL)
909933
}

inst/include/base_rng.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include <Rcpp.h>
2+
#include <boost/random/additive_combine.hpp>
3+
4+
// [[Rcpp::export]]
5+
SEXP base_rng(boost::uint32_t seed = 0) {
6+
Rcpp::XPtr<boost::ecuyer1988> rng_ptr(new boost::ecuyer1988(seed));
7+
return rng_ptr;
8+
}

tests/testthat/test-model-expose-functions.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ test_that("Functions can be compiled with model", {
112112

113113
test_that("rng functions can be exposed", {
114114
skip_if(os_is_wsl())
115-
function_decl <- "functions { real normal_rng(real mu) { return normal_rng(mu, 1); } }"
115+
function_decl <- "functions { real wrap_normal_rng(real mu, real sigma) { return normal_rng(mu, sigma); } }"
116116
stan_prog <- paste(function_decl,
117117
paste(readLines(testing_stan_file("bernoulli")),
118118
collapse = "\n"),
@@ -122,11 +122,17 @@ test_that("rng functions can be exposed", {
122122
mod <- cmdstan_model(model, force_recompile = TRUE)
123123
fit <- mod$sample(data = data_list)
124124

125+
set.seed(10)
125126
fit$expose_functions(verbose = TRUE)
126127

127128
expect_equal(
128-
fit$functions$normal_rng(5, seed = 10),
129-
3.8269637967017344771
129+
fit$functions$wrap_normal_rng(5,10),
130+
-4.5298764235381225873
131+
)
132+
133+
expect_equal(
134+
fit$functions$wrap_normal_rng(5,10),
135+
8.1295902610102039887
130136
)
131137
})
132138

0 commit comments

Comments
 (0)