在Rcpp中使用多项式

Using rmultinom with Rcpp

本文关键字:多项式 Rcpp      更新时间:2023-10-16

我想在c++代码中使用R函数rmultinom与Rcpp一起使用。我得到了一个关于没有足够参数的错误-我不熟悉这些参数应该是什么,因为它们不对应于R中函数使用的参数。我也没有任何运气使用"::Rf_foo"语法从Rcpp代码访问R函数。

下面是我的代码的简化版本(是的,我正在编写吉布斯采样器)。

#include <Rcpp.h>                                                                                                                                     
using namespace Rcpp;                                                                                                                                 
// C++ implementation of the R which() function.                                                                                                      
int whichC(NumericVector x, double val) {                                                                                                             
  int ind = -1;                                                                                                                                       
  int n = x.size();                                                                                                                                   
  for (int i = 0; i < n; ++i) {                                                                                                                       
    if (x[i] == val) {                                                                                                                                
      if (ind == -1) {                                                                                                                                
        ind = i;                                                                                                                                      
      } else {                                                                                                                                        
        throw std::invalid_argument( "value appears multiple times." );                                                                               
      }                                                                                                                                               
    } // end if                                                                                                                                       
  } // end for                                                                                                                                        
  if (ind != -1) {                                                                                                                                    
    return ind;                                                                                                                                       
  } else {                                                                                                                                            
    throw std::invalid_argument( "value doesn't appear here!" );                                                                                      
    return -1;                                                                                                                                        
  }                                                                                                                                                   
}                                                                                                                                                     
// [[Rcpp::export]]                                                                                                                                   
int multSample(double p1, double p2, double p3) {                                                                                                     
  NumericVector params(3);                                                                                                                            
  params(0) = p1;                                                                                                                                     
  params(1) = p2;                                                                                                                                     
  params(2) = p3;                                                                                                                                     
  // HERE'S THE PROBLEM.                                                                                                                              
  RObject sampled = rmultinom(1, 1, params);                                                                                                          
  int out = whichC(as<NumericVector>(sampled), 1);                                                                                                    
  return out;                                                                                                                                         
}

我是c++的新手,所以我意识到很多代码可能是笨拙和低效的。我愿意听取关于如何改进我的c++代码的建议,但我的首要任务是听取有关多项业务的建议。谢谢!

顺便说一句,我很抱歉与这篇文章相似,但是

  1. 答案不适合我的目的
  2. 差异可能足以保证一个不同的问题(你这样认为吗?)
  3. 这个问题是一年前发布并回答的。

下面是user95215修改的答案,以便它可以编译,并且是Rcpp风格的另一个版本:

#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
IntegerVector oneMultinomC(NumericVector probs) {
    int k = probs.size();
    SEXP ans;
    PROTECT(ans = Rf_allocVector(INTSXP, k));
    probs = Rf_coerceVector(probs, REALSXP);
    rmultinom(1, REAL(probs), k, &INTEGER(ans)[0]);
    UNPROTECT(1);
    return(ans);
}
// [[Rcpp::export]]
IntegerVector oneMultinomCalt(NumericVector probs) {
    int k = probs.size();
    IntegerVector ans(k);
    rmultinom(1, probs.begin(), k, ans.begin());
    return(ans);
}

如果我尝试编译您的代码,我得到编译器错误:

> Rcpp::sourceCpp('~/scratch/multSample.cpp')
multSample.cpp:33:21: error: no matching function for call to 'rmultinom'
  RObject sampled = rmultinom(1, 1, params);
                    ^~~~~~~~~
/Library/Frameworks/R.framework/Resources/include/Rmath.h:449:6: note: candidate function not viable: requires 4 arguments, but 3 were provided
void    rmultinom(int, double*, int, int*);
        ^
1 error generated.

正如它所暗示的那样,您没有正确指定参数。请注意,与其他函数相比,rmultinom接口有点尴尬:它填充*rn所指向的内存,而不是返回一个新对象(带有自己新分配的内存)。

如果你看一下R源码,你会看到接口,你可以看到它在这里被使用的一个例子(事实上,stats做了一个包装函数,做了一些更多的参数检查和什么的)。注意这里的用法:

rmultinom(size, REAL(prob), k, &INTEGER(ans)[ik]);

换句话说,它通过将该内存的地址传递给rmultinom函数来填充名为ansINTSXP

因此,如果您想使用Rcpp中的这个函数,您将不得不做类似的事情——但也许这需要类似的糖矢量化处理,以避免该接口的丑陋。

你可以试着这样做:

IntegerMatrix sampled(nrow, ncol);
rmultinom(1, 1, params, sampled.begin());

Kevin提供的示例和链接使我找到了一个有效的答案。有一些类型的争论。我写了一个函数,它允许你从一个多项分布中抽样一个向量。代码如下。

#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
NumericVector oneMultinomC(NumericVector probs) {
    int k = probs.size();
    SEXP ans;
    PROTECT(ans = RF_allocVector(INTSXP, k));
    probs = RF_coerceVector(probs, REALSXP);
    rmultinom(1, REAL(probs), k, &INTEGER(ans)[0]);
    UNPROTECT(1);
    return ans;
}
这里发生的事我一半都不明白。特别是,我不理解"rmultinom"的第四个参数。我知道它是一个指向存储输出的内存位置的指针,但我不理解'[0]'位。尽管如此,它还是有效的。孩子们,去取样吧。