11#include < Rcpp.h>
2+ #include < rcpp_eigen_interop.hpp>
23#include < stan/model/model_base.hpp>
34#include < stan/model/log_prob_grad.hpp>
45#include < stan/model/log_prob_propto.hpp>
@@ -26,10 +27,14 @@ using json_data_t = stan::json::json_data;
2627 return std::make_shared<json_data_t >(data_context);
2728}
2829
30+ stan::model::model_base&
31+ new_model (stan::io::var_context& data_context, unsigned int seed,
32+ std::ostream* msg_stream);
33+
2934// [[Rcpp::export]]
3035Rcpp::List model_ptr (std::string data_path, boost::uint32_t seed) {
31- Rcpp::XPtr<stan_model > ptr (
32- new stan_model (*var_context (data_path), seed, &Rcpp::Rcout)
36+ Rcpp::XPtr<stan::model::model_base > ptr (
37+ & new_model (*var_context (data_path), seed, &Rcpp::Rcout)
3338 );
3439 Rcpp::XPtr<boost::ecuyer1988> base_rng (new boost::ecuyer1988 (seed));
3540 return Rcpp::List::create (
@@ -39,37 +44,56 @@ Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
3944}
4045
4146// [[Rcpp::export]]
42- double log_prob (SEXP ext_model_ptr, std::vector<double > upars,
43- bool jac_adjust) {
47+ double log_prob (SEXP ext_model_ptr, Eigen::VectorXd upars, bool jac_adjust) {
4448 Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
45- std::vector<int > params_i;
4649 if (jac_adjust) {
47- return stan::model::log_prob_propto<true >(*ptr.get (), upars, params_i, &Rcpp::Rcout);
50+ return stan::model::log_prob_propto<true >(*ptr.get (), upars, &Rcpp::Rcout);
4851 } else {
49- return stan::model::log_prob_propto<false >(*ptr.get (), upars, params_i, &Rcpp::Rcout);
52+ return stan::model::log_prob_propto<false >(*ptr.get (), upars, &Rcpp::Rcout);
5053 }
5154}
5255
5356// [[Rcpp::export]]
54- Rcpp::NumericVector grad_log_prob (SEXP ext_model_ptr, std::vector< double > upars,
57+ Rcpp::NumericVector grad_log_prob (SEXP ext_model_ptr, Eigen::VectorXd upars,
5558 bool jac_adjust) {
5659 Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
57- std::vector<double > gradients;
58- std::vector<int > params_i;
60+ Eigen::VectorXd gradients;
5961
6062 double lp;
6163 if (jac_adjust) {
62- lp = stan::model::log_prob_grad<true , true >(
63- *ptr.get (), upars, params_i, gradients);
64+ lp = stan::model::log_prob_grad<true , true >(*ptr.get (), upars, gradients);
6465 } else {
65- lp = stan::model::log_prob_grad<true , false >(
66- *ptr.get (), upars, params_i, gradients);
66+ lp = stan::model::log_prob_grad<true , false >(*ptr.get (), upars, gradients);
6767 }
68- Rcpp::NumericVector grad_rtn = Rcpp::wrap (gradients);
68+ Rcpp::NumericVector grad_rtn ( Rcpp::wrap (std::move ( gradients)) );
6969 grad_rtn.attr (" log_prob" ) = lp;
7070 return grad_rtn;
7171}
7272
73+ // [[Rcpp::export]]
74+ Rcpp::List hessian (SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
75+ Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
76+
77+ auto hessian_functor = [&](auto && x) {
78+ if (jacobian) {
79+ return ptr->log_prob <true , true >(x, 0 );
80+ } else {
81+ return ptr->log_prob <true , false >(x, 0 );
82+ }
83+ };
84+
85+ double log_prob;
86+ Eigen::VectorXd grad;
87+ Eigen::MatrixXd hessian;
88+
89+ stan::math::internal::finite_diff_hessian_auto (hessian_functor, upars, log_prob, grad, hessian);
90+
91+ return Rcpp::List::create (
92+ Rcpp::Named (" log_prob" ) = log_prob,
93+ Rcpp::Named (" grad_log_prob" ) = grad,
94+ Rcpp::Named (" hessian" ) = hessian);
95+ }
96+
7397// [[Rcpp::export]]
7498size_t get_num_upars (SEXP ext_model_ptr) {
7599 Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
@@ -95,12 +119,23 @@ Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
95119}
96120
97121// [[Rcpp::export]]
98- std::vector< double > unconstrain_variables (SEXP ext_model_ptr, std::string init_path ) {
122+ Eigen::VectorXd unconstrain_variables (SEXP ext_model_ptr, Eigen::VectorXd variables ) {
99123 Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
100- std::vector<int > params_i;
101- std::vector<double > vars;
102- ptr->transform_inits (*var_context (init_path), params_i, vars, &Rcpp::Rcout);
103- return vars;
124+ Eigen::VectorXd unconstrained_variables;
125+ ptr->unconstrain_array (variables, unconstrained_variables, &Rcpp::Rcout);
126+ return unconstrained_variables;
127+ }
128+
129+ // [[Rcpp::export]]
130+ Eigen::MatrixXd unconstrain_draws (SEXP ext_model_ptr, Eigen::MatrixXd variables) {
131+ Rcpp::XPtr<stan::model::model_base> ptr (ext_model_ptr);
132+ Eigen::MatrixXd unconstrained_draws (variables.cols (), variables.rows ());
133+ for (int i = 0 ; i < variables.rows (); i++) {
134+ Eigen::VectorXd unconstrained_variables;
135+ ptr->unconstrain_array (variables.transpose ().col (i), unconstrained_variables, &Rcpp::Rcout);
136+ unconstrained_draws.col (i) = unconstrained_variables;
137+ }
138+ return unconstrained_draws.transpose ();
104139}
105140
106141// [[Rcpp::export]]
0 commit comments