Rcpp NumericMatrix-如何擦除行/列

Rcpp NumericMatrix - how to erase a row / column?

本文关键字:擦除 NumericMatrix- 何擦除 Rcpp      更新时间:2023-10-16

当我学习Rcpp类/数据结构时,一个新手问题是:是否有成员函数可以擦除Rcpp::NumericMatrix类对象的行/列?(或者其他类型的type **Matrix——我假设它是一个模板类)?

library(Rcpp)
cppFunction('
  NumericMatrix sub1 {NumericMatrix x, int& rowID, int& colID) {
    // let's assume separate functions for rowID or colID
    // but for the example case here
    x.row(rowID).erase(); // ??? does this type of member function exist?
    x.col(colID).erase(); // ???
    return x;
}')

如果这种类型的成员函数不存在,这个怎么样?

cppFunction('NumericMatrix row_erase (NumericMatrix& x, int& rowID) {
  // a similar function would exist for removing a column.
  NumericMatrix x2(Dimension(x.nrow()-1, x.ncol());
  int iter = 0; // possibly make this a pointer?
  for (int i = 0; i < x.nrow(); i++) {
    if (i != rowID) {
      x2.row(iter) = x.row(i);
      iter++;
    }
  }
  return x2;
}')

或者我们希望删除一组行/列:

cppFunction('NumericMatrix row_erase (NumericMatrix& x, IntegerVector& rowID) {
  // a similar function would exist for removing a column.
  rowID = rowID.sort();
  NumericMatrix x2(Dimension(x.nrow()- rowID.size(), x.ncol());
  int iter = 0; // possibly make this a pointer?
  int del = 1; // to count deleted elements
  for (int i = 0; i < x.nrow(); i++) {
    if (i != rowID[del - 1])
      x2.row(iter) = x.row(i);
      iter++;
    } else {
      del++;
    }
  }
  return x2;
}')

使用RcppArmadillo怎么样?我认为代码的意图会更加清晰。。。

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace arma;
// [[Rcpp::export]]
mat sub1( mat x, uword e) {
  x.shed_col(e-1);
  x.shed_row(e-1);
  return x;
}
/*** R
sub1( matrix(1:9,3), 2 )
*/

> sub1( matrix(1:9,3), 2 )
     [,1] [,2]
[1,]    1    7
[2,]    3    9

是的,这两种方法都有效(修复了我上面的打字错误)。不过,我在尝试用Rcpp::NumericMatrix::iterator iter替换int iter时出现了转换错误。有什么解决办法吗?

注意,我们不需要row_erase(NumericMatrix& x, int& ref),因为这是row_erase(NumericMatrix& x, IntegerVector& ref)的特殊情况。

NumericMatrix row_erase (NumericMatrix& x, IntegerVector& rowID) {
  rowID = rowID.sort();
  NumericMatrix x2(Dimension(x.nrow()- rowID.size(), x.ncol()));
  int iter = 0; 
  int del = 1; // to count deleted elements
  for (int i = 0; i < x.nrow(); i++) {
    if (i != rowID[del - 1]) {
      x2.row(iter) = x.row(i);
      iter++;
    } else {
      del++;
    }
  }
  return x2;
}
NumericMatrix col_erase (NumericMatrix& x, IntegerVector& colID) {
  colID = colID.sort();
  NumericMatrix x2(Dimension(x.nrow(), x.ncol()- colID.size()));
  int iter = 0; 
  int del = 1; 
  for (int i = 0; i < x.ncol(); i++) {
    if (i != colID[del - 1]) {
      x2.col(iter) = x.column(i);
      iter++;
    } else {
      del++;
    }
  }
  return x2;
}