@@ -748,7 +748,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
748748 package = " cmdstanr" , mustWork = TRUE )))
749749
750750 if (hessian ) {
751- code <- c(code ,
751+ code <- c(" #include <stan/math/mix.hpp>" ,
752+ code ,
752753 readLines(system.file(" include" , " hessian.cpp" ,
753754 package = " cmdstanr" , mustWork = TRUE )))
754755 }
@@ -758,9 +759,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
758759 invisible (NULL )
759760}
760761
761- initialize_model_pointer <- function (env , data , seed = 0 ) {
762- datafile_path <- ifelse(is.null(data ), " " , data )
763- ptr_and_rng <- env $ model_ptr(datafile_path , seed )
762+ initialize_model_pointer <- function (env , datafile_path , seed = 0 ) {
763+ ptr_and_rng <- env $ model_ptr(ifelse(is.null(datafile_path ), " " , datafile_path ), seed )
764764 env $ model_ptr_ <- ptr_and_rng $ model_ptr
765765 env $ model_rng_ <- ptr_and_rng $ base_rng
766766 env $ num_upars_ <- env $ get_num_upars(env $ model_ptr_ )
@@ -863,8 +863,8 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
863863 fun_body <- gsub(" auto" , get_plain_rtn(fun_start , fun_end , model_lines ), fun_body )
864864 fun_body <- gsub(" // [[stan::function]]" , " // [[Rcpp::export]]\n " , fun_body , fixed = TRUE )
865865 fun_body <- gsub(" std::ostream\\ *\\ s*pstream__\\ s*=\\ s*nullptr" , " " , fun_body )
866- fun_body <- gsub(" boost::ecuyer1988& base_rng__" , " size_t seed = 0 " , fun_body , fixed = TRUE )
867- fun_body <- gsub(" base_rng__," , " *(new boost::ecuyer1988(seed ))," , fun_body , fixed = TRUE )
866+ fun_body <- gsub(" boost::ecuyer1988&\\ s* base_rng__" , " SEXP base_rng_ptr " , fun_body )
867+ fun_body <- gsub(" base_rng__," , " *(Rcpp::XPtr< boost::ecuyer1988>(base_rng_ptr).get( ))," , fun_body , fixed = TRUE )
868868 fun_body <- gsub(" pstream__" , " &Rcpp::Rcout" , fun_body , fixed = TRUE )
869869 fun_body <- paste(fun_body , collapse = " \n " )
870870 gsub(pattern = " ,\\ s*)" , replacement = " )" , fun_body )
@@ -904,6 +904,30 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
904904 } else {
905905 rcpp_source_stan(mod_stan_funs , env , verbose )
906906 }
907+
908+ # If an RNG function is exposed, initialise a Boost RNG object stored in the
909+ # environment
910+ rng_funs <- grep(" rng\\ b" , env $ fun_names , value = TRUE )
911+ if (length(rng_funs ) > 0 ) {
912+ rng_cpp <- system.file(" include" , " base_rng.cpp" , package = " cmdstanr" , mustWork = TRUE )
913+ rcpp_source_stan(paste0(readLines(rng_cpp ), collapse = " \n " ), env , verbose )
914+ env $ rng_ptr <- env $ base_rng(seed = 0 )
915+ }
916+
917+ # For all RNG functions, pass the initialised Boost RNG by default
918+ for (fun in rng_funs ) {
919+ if (global ) {
920+ fun_env <- globalenv()
921+ } else {
922+ fun_env <- env
923+ }
924+ fundef <- get(fun , envir = fun_env )
925+ funargs <- formals(fundef )
926+ funargs $ base_rng_ptr <- env $ rng_ptr
927+ formals(fundef ) <- funargs
928+ assign(fun , fundef , envir = fun_env )
929+ }
930+
907931 env $ compiled <- TRUE
908932 invisible (NULL )
909933}
0 commit comments