推力的"+"运算符过载,有什么想法吗?

Overloading "+" operator for Thrust, any ideas?

本文关键字:什么 运算符      更新时间:2023-10-16

我正在使用CUDA和Thrust。我发现输入thrust::transform [plus/minus/divide]很乏味,所以我只想重载一些简单的操作符。

如果我能做到:

thrust::[host/device]_vector<float> host;
thrust::[host/device]_vector<float> otherHost;
thrust::[host/device]_vector<float> result = host + otherHost;

下面是+的示例代码片段:

template <typename T>
__host__ __device__ T& operator+(T &lhs, const T &rhs) {
    thrust::transform(rhs.begin(), rhs.end(),
                      lhs.begin(), lhs.end(), thrust::plus<?>());
    return lhs;
}

然而,thrust::plus<?>没有正确过载,或者我没有正确地做它…一个或另一个。(如果重载简单操作符是一个坏主意,请解释原因)。最初,我认为我可以用typename T::iterator之类的东西重载?占位符,但这不起作用。

我不确定如何用vector 的类型和vector迭代器的类型重载+操作符。这有道理吗?

谢谢你的帮助!

这似乎有效,其他人可能有更好的主意:

#include <ostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#define DSIZE 10

template <typename T>
thrust::device_vector<T>  operator+(thrust::device_vector<T> &lhs, const thrust::device_vector<T> &rhs) {
    thrust::transform(rhs.begin(), rhs.end(),
                      lhs.begin(), lhs.begin(), thrust::plus<T>());
    return lhs;
}
template <typename T>
thrust::host_vector<T>  operator+(thrust::host_vector<T> &lhs, const thrust::host_vector<T> &rhs) {
    thrust::transform(rhs.begin(), rhs.end(),
                      lhs.begin(), lhs.begin(), thrust::plus<T>());
    return lhs;
}
int main() {

  thrust::device_vector<float> dvec(DSIZE);
  thrust::device_vector<float> otherdvec(DSIZE);
  thrust::fill(dvec.begin(), dvec.end(), 1.0f);
  thrust::fill(otherdvec.begin(), otherdvec.end(), 2.0f);
  thrust::host_vector<float> hresult1 = dvec + otherdvec;
  std::cout << "result 1: ";
  thrust::copy(hresult1.begin(), hresult1.end(), std::ostream_iterator<float>(std::cout, " "));  std::cout << std::endl;
  thrust::host_vector<float> hvec(DSIZE);
  thrust::fill(hvec.begin(), hvec.end(), 5.0f);
  thrust::host_vector<float> hresult2 = hvec + hresult1;

  std::cout << "result 2: ";
  thrust::copy(hresult2.begin(), hresult2.end(), std::ostream_iterator<float>(std::cout, " "));  std::cout << std::endl;
  // this line would produce a compile error:
  // thrust::host_vector<float> hresult3 = dvec + hvec;
  return 0;
}

注意,在这两种情况下,我都可以为结果指定主机或设备向量,因为thrust将看到差异并自动生成必要的复制操作。因此,我的模板中的结果向量类型(主机,设备)并不重要。

还要注意,模板定义中的thrust::transform函数参数不太正确。