多项目替换在cuda推力

multi item replacing in cuda thrust

本文关键字:cuda 推力 替换 项目      更新时间:2023-10-16

我有一个设备向量a,B,C如下:

A = [1,1,3,3,3,4,4,5,5]
B = [1,3,5]
C = [2,8,6]

我想用C中相应的元素替换a中的每个B。如:

  • 1被2取代,
  • 3被8取代,
  • 5被6取代

,从而得到如下结果

Result = [2,2,8,8,8,4,4,6,6]

我如何在cuda推力或cuda c++中实现它的任何方式。我发现了一次性替换单个元素的thrust::replace。因为我需要替换大量的数据,所以每次只替换一个数据就成了瓶颈。

这可以通过首先构建一个map,然后应用一个自定义函子来查询该map来有效地完成。

示例代码执行以下步骤:

  1. 获取C中最大的元素。

  2. 创建大小为largest_element的映射向量。将新值复制到旧值的位置

  3. mapper函子应用于A。这个函子从映射向量中读取new_value。如果new_value而不是 0,则A中的值将被替换为新的值。这里假设C不包含0。如果可以包含0,则必须使用另一个条件,例如,用-1初始化映射向量,并检查new_value != -1


#include <thrust/device_vector.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/copy.h>
#include <thrust/for_each.h>
#include <thrust/scatter.h>
#include <iostream>

#define PRINTER(name) print(#name, (name))
template <template <typename...> class V, typename T, typename ...Args>
void print(const char* name, const V<T,Args...> & v)
{
    std::cout << name << ":t";
    thrust::copy(v.begin(), v.end(), std::ostream_iterator<T>(std::cout, "t"));
    std::cout << std::endl;
}

template <typename T>
struct mapper
{
    mapper(thrust::device_ptr<const T> map) : map(map)
    {
    }
    __host__ __device__
    void operator()(T& value) const
    {
       const T& new_value = map[value]; 
       if (new_value)
       {
          value = new_value;
       }
    }
    thrust::device_ptr<const T> map;
};
int main()
{
    using namespace thrust::placeholders;
    int A[] = {1,1,3,3,3,4,4,5,5};
    int B[] = {1,3,5};
    int C[] = {2,8,6};
    int size_data    = sizeof(A)/sizeof(A[0]);
    int size_replace = sizeof(B)/sizeof(B[0]);
    // copy demo data to GPU
    thrust::device_vector<int> d_A (A, A+size_data);
    thrust::device_vector<int> d_B (B, B+size_replace);
    thrust::device_vector<int> d_C (C, C+size_replace);
    PRINTER(d_A);
    PRINTER(d_B);
    PRINTER(d_C);
    int largest_element = d_C.back();
    thrust::device_vector<int> d_map(largest_element);
    thrust::scatter(d_C.begin(), d_C.end(), d_B.begin(), d_map.begin());
    PRINTER(d_map);
    thrust::for_each(d_A.begin(), d_A.end(), mapper<int>(d_map.data()));
    PRINTER(d_A);
    return 0;
}

d_A:    1   1   3   3   3   4   4   5   5   
d_B:    1   3   5   
d_C:    2   8   6   
d_map:  0   2   0   8   0   6   
d_A:    2   2   8   8   8   4   4   6   6