Rcpp slower than R for a sampling problem

Keywords: Rcpp, sampling    Updated: 2023-10-16

Problem setup

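The problem consists of sampling n days out of the 365 days of a year, such that

  • the days are drawn from a uniform probability distribution
  • the days respect a minimum distance given by min_dist
  • the result is returned as a numeric vector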

For n = 12 and min_dist = 20, a correct result might be the vector [1] 4 43 69 97 129 161 192 215 243 285 309 343. Its diff, [1] 39 26 28 32 32 31 23 28 42 24 34, shows that all gaps are greater than or equal to min_dist = 20.
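The constraints on a single draw can be checked mechanically. The following C++ sketch uses a hypothetical helper, satisfies_min_dist(), that is not part of the question. It checks that a result lies in 1..365, is strictly increasing, and has every gap at least min_dist. It cannot check the uniformity requirement, which is a property of the sampler rather than of one draw.

```cpp
#include <cstddef>
#include <vector>

// Returns true if `days` is a valid draw: every day in 1..365, strictly
// increasing, and consecutive days at least `min_dist` apart
// (i.e. every element of diff(days) is >= min_dist).
bool satisfies_min_dist(const std::vector<int>& days, int min_dist) {
    for (std::size_t i = 0; i < days.size(); ++i) {
        if (days[i] < 1 || days[i] > 365) return false;
        if (i > 0) {
            int gap = days[i] - days[i - 1];
            if (gap <= 0 || gap < min_dist) return false;
        }
    }
    return true;
}
```

For the example vector above, the smallest gap is 23 (192 to 215). The check therefore passes for min_dist = 20 and min_dist = 23 and fails for min_dist = 24.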

Question

I have solved the problem with

  • a function sample_r() in native R
  • a function sample_cpp() in C++, using the excellent Rcpp interface package

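The C++ solution is much slower (by a factor of about 60 on my Mac). I am an Rcpp newbie, so my own research abilities are limited; please forgive me. What can I do to refactor the C++ code so that it runs faster than the native R code?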
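One direction for such a refactoring is to keep the whole loop in plain C++ and replace the repeated setdiff() calls on IntegerVector with a fixed availability mask. The sketch below uses a hypothetical standalone function, sample_days_mask(), that does not appear in the question. It restarts when fewer than n - i days remain, like the edited sample_cpp(). It uses std::mt19937 instead of R's RNG, so its draws will not match set.seed() in R.

```cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <stdexcept>
#include <vector>

// Sequential sampler with restarts, using a 366-slot availability mask
// (slot 0 unused) instead of setdiff(). Assumption: std::mt19937 stands in
// for R's RNG, so results are not reproducible with set.seed() in R.
std::vector<int> sample_days_mask(int n, int min_dist, unsigned seed) {
    if (n < 0 || min_dist < 1 || n * min_dist >= 365)
        throw std::invalid_argument("Infeasible.");
    std::mt19937 rng(seed);
    std::vector<int> res, candidates;
    for (;;) {  // restart until a complete draw succeeds
        std::vector<bool> available(366, true);
        available[0] = false;
        res.clear();
        for (int i = 0; i < n; ++i) {
            candidates.clear();
            for (int d = 1; d <= 365; ++d)
                if (available[d]) candidates.push_back(d);
            // early exit: fewer days left than picks still needed
            if (static_cast<int>(candidates.size()) < n - i) break;
            std::uniform_int_distribution<std::size_t> pick(0, candidates.size() - 1);
            int day = candidates[pick(rng)];
            res.push_back(day);
            // block every day closer than min_dist to the chosen day
            int lo = std::max(1, day - (min_dist - 1));
            int hi = std::min(365, day + (min_dist - 1));
            for (int d = lo; d <= hi; ++d) available[d] = false;
        }
        if (static_cast<int>(res.size()) == n) break;
    }
    std::sort(res.begin(), res.end());
    return res;
}
```

Each pick costs a fixed scan of 365 slots and allocates nothing new per iteration. Rcpp can export a function that returns std::vector<int> directly and converts the result to an integer vector in R.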

Reproducible code (.cpp file)

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
IntegerVector sample_cpp(int n, int min_dist = 5L, int seed = 42L) {

  IntegerVector res_empty = Rcpp::rep(NA_INTEGER, n);
  IntegerVector res;
  IntegerVector available_days_full = Rcpp::seq(1, 365);
  IntegerVector available_days;
  IntegerVector forbidden_days;
  IntegerVector forbidden_space = Rcpp::seq(-(min_dist - 1), (min_dist - 1));
  bool fail;
  // seed R's RNG (used by Rcpp's sample()) via base::set.seed
  Environment base("package:base");
  Function set_seed = base["set.seed"];
  set_seed(seed);
  do {
    res = res_empty;  // shallow copy: res and res_empty share memory
    available_days = available_days_full;
    fail = false;
    for (int i = 0; i < n; ++i) {
      res[i] = sample(available_days, 1, false)[0];
      forbidden_days = res[i] + forbidden_space;
      available_days = setdiff(available_days, forbidden_days);
      // stricter than the R version: this check also runs after the last
      // pick, so a complete draw is rejected when <= 1 day remains
      if (available_days.size() <= 1) {
        fail = true;
        break;
      }
    }
  } while (fail);
  std::sort(res.begin(), res.end());
  return res;
}

/*** R
# C++ function
(r = sample_cpp(n = 12, min_dist = 20, seed = 1))
diff(r)

# R function
sample_r = function(n = 12, min_dist = 5, seed = 42) {
  if (n * min_dist >= 365) stop("Infeasible.")
  set.seed(seed)
  repeat {
    res = numeric(n)
    fail = FALSE
    available_days = seq(365)
    for (i in seq(n)) {
      if (length(available_days) <= 1) {
        fail = TRUE
        break
      }
      res[i] = sample(available_days, 1)
      forbidden_days = res[i] + (-(min_dist - 1):(min_dist - 1))
      available_days = setdiff(available_days, forbidden_days)
    }
    if (!fail) return(sort(res))
  }
}
(r = sample_r(n = 12, min_dist = 20, seed = 40))
diff(r)

# Benchmark
library(rbenchmark)
benchmark(cpp = sample_cpp(n = 12, min_dist = 28),
          r = sample_r(n = 12, min_dist = 28),
          replications = 50)[, 1:4]
*/

Benchmark:

test replications elapsed relative
1  cpp           50  28.005   63.217
2    r           50   0.443    1.000

Edit: OK, I tried to optimize the code (as far as my C++ skills allow). The C++ implementation is still behind, but now only slightly.

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
IntegerVector sample_cpp(int n, int min_dist = 5L, int seed = 42L) {

  IntegerVector res;
  IntegerVector available_days;
  IntegerVector forbidden_days;
  IntegerVector forbidden_space = Rcpp::seq(-(min_dist - 1), (min_dist - 1));
  bool fail;
  // seed R's RNG (used by Rcpp's sample()) via base::set.seed
  Environment base("package:base");
  Function set_seed = base["set.seed"];
  set_seed(seed);
  do {
    res = Rcpp::rep(NA_INTEGER, n);
    available_days = Rcpp::seq(1, 365);
    fail = false;
    for (int i = 0; i < n; ++i) {
      // restart early if fewer days remain than picks still needed
      if (available_days.size() < n - i) {
        fail = true;
        break;
      }
      int temp = sample(available_days, 1, false)[0];
      res[i] = temp;
      forbidden_days = unique(pmax(0, temp + forbidden_space));
      available_days = setdiff(available_days, forbidden_days);
    }
  } while (fail);
  std::sort(res.begin(), res.end());
  return res;
}

/*** R
# R function
sample_r = function(n = 12, min_dist = 5, seed = 42) {
  if (n * min_dist >= 365) stop("Infeasible.")
  set.seed(seed)
  repeat {
    res = numeric(n)
    fail = FALSE
    available_days = seq(365)
    for (i in seq(n)) {
      if (length(available_days) <= n - i) {
        fail = TRUE
        break
      }
      res[i] = sample(available_days, 1)
      forbidden_days = res[i] + (-(min_dist - 1):(min_dist - 1))
      available_days = setdiff(available_days, forbidden_days)
    }
    if (!fail) return(sort(res))
  }
}

# Benchmark
library(rbenchmark)
benchmark(cpp = sample_cpp(n = 12, min_dist = 28),
          r = sample_r(n = 12, min_dist = 28),
          replications = 50)[, 1:4]
*/

Benchmark:

test replications elapsed relative
1  cpp           50   0.643    1.475
2    r           50   0.436    1.000

Answer

You can speed up the R version by sampling as many days as possible in one go.

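The following code is faster than yours. Statistically, most of the days are sampled before the loop. The remaining days are sampled inside the loop, which will probably run only once, maybe twice.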
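The step that lets most days be accepted in a single batch is the thinning of a sorted draw. The C++ sketch below uses a hypothetical helper, thin(). It keeps a day only if the next day is at least min_dist later, and it always keeps the last day because that day has no successor. Any two kept days are then at least min_dist apart, because the next kept day can be no earlier than the successor that was checked.

```cpp
#include <cstddef>
#include <vector>

// Thinning step of the batch idea: keep sorted_days[i] only when the next
// day is at least min_dist later; the last day is always kept.
std::vector<int> thin(const std::vector<int>& sorted_days, int min_dist) {
    std::vector<int> kept;
    for (std::size_t i = 0; i < sorted_days.size(); ++i) {
        bool is_last = (i + 1 == sorted_days.size());
        if (is_last || sorted_days[i + 1] - sorted_days[i] >= min_dist)
            kept.push_back(sorted_days[i]);
    }
    return kept;
}
```

For example, thin({4, 10, 43, 50, 69}, 20) drops 4 (gap 6), 43 (gap 7) and 50 (gap 19), and returns {10, 69}.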

Moreover, this approach is easy to rewrite with Rcpp.

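sample_r2 = function(n = 12, min_dist = 5, seed = 42)
{
  if (n * min_dist >= 365) stop("Infeasible.")
  set.seed(seed)
  # keep res[i] only if the next day is at least min_dist later;
  # the last day has no successor and is always kept
  thin = function(res) res[c(diff(res) >= min_dist, TRUE)]
  repeat
  {
    available_days = seq(365)
    res = thin(sort(sample(available_days, n)))
    while (length(res) < n)
    {
      # remove every day closer than min_dist to an accepted day
      forbidden_days = sapply(res, function(x) x + -(min_dist-1):(min_dist-1))
      available_days = setdiff(available_days, forbidden_days)
      k = n - length(res)
      if (length(available_days) < k) break  # dead end: start over
      # sample.int() avoids sample()'s 1:x behaviour for a single remaining day
      days = available_days[sample.int(length(available_days), k)]
      res = thin(sort(c(res, days)))
    }
    if (length(res) == n) return(res)
  }
}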
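Since the batch idea is said to be easy to port, here is one possible plain C++ version. It is a sketch under stated assumptions. sample_days_batch() and thin_sorted() are hypothetical names. std::mt19937 replaces R's RNG, and a dead end triggers a full restart. Each round draws all missing days at once with a partial Fisher-Yates shuffle, thins the merged sorted result, and refills only from days that are at least min_dist away from every accepted day.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>

// Keep d[i] only if the next sorted day is at least min_dist later;
// the last day has no successor and is always kept.
static std::vector<int> thin_sorted(const std::vector<int>& d, int min_dist) {
    std::vector<int> kept;
    for (std::size_t i = 0; i < d.size(); ++i)
        if (i + 1 == d.size() || d[i + 1] - d[i] >= min_dist) kept.push_back(d[i]);
    return kept;
}

// Batch sampler: draw every missing day at once, thin, refill; restart on a
// dead end. Assumption: std::mt19937 stands in for R's RNG.
std::vector<int> sample_days_batch(int n, int min_dist, unsigned seed) {
    if (n < 0 || min_dist < 1 || n * min_dist >= 365)
        throw std::invalid_argument("Infeasible.");
    std::mt19937 rng(seed);
    for (;;) {  // one full attempt per iteration
        std::vector<int> res;
        for (;;) {  // refill rounds
            std::size_t k = static_cast<std::size_t>(n) - res.size();
            if (k == 0) return res;  // res is sorted and thinned
            // days at least min_dist away from every accepted day
            std::vector<int> available;
            for (int day = 1; day <= 365; ++day) {
                bool ok = true;
                for (int r : res)
                    if (std::abs(day - r) < min_dist) { ok = false; break; }
                if (ok) available.push_back(day);
            }
            if (available.size() < k) break;  // dead end: start a new attempt
            // partial Fisher-Yates: the first k entries become a uniform sample
            for (std::size_t i = 0; i < k; ++i) {
                std::uniform_int_distribution<std::size_t> pick(i, available.size() - 1);
                std::swap(available[i], available[pick(rng)]);
            }
            for (std::size_t i = 0; i < k; ++i) res.push_back(available[i]);
            std::sort(res.begin(), res.end());
            res = thin_sorted(res, min_dist);
        }
    }
}
```

Every refill day is at least min_dist from the accepted days, so thinning never removes an accepted day. The largest new day always survives, so each round adds at least one day and an attempt ends after at most n rounds.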

By the way, there may be some problems in my code, but I think the idea is right.