在主机或设备上运行 thrust::max_element

Is thrust::max_element run on host or device?

本文关键字:thrust max element 运行 主机      更新时间:2023-10-16

我需要在存储在设备上的长数组中找到最大元素。 我想我可以使用 thrust::max_element 来做到这一点。 我在代码下面的代码中的while循环中调用thrust::max_element。 我只是给它两个设备指针(请注意,real只是float的typedef(。 我不能只传递推力::max_element设备指针吗? 它是否试图在主机上找到最大元素? 我问这个是因为在那之后我的代码因 seg 错误而失败。

int main()
{
    cuda_error(cudaSetDevice(1), "set device");
    const size_t DIM = 50;
    real* grid_d;
    real* next_grid_d;
    cuda_error(cudaMalloc(&grid_d, sizeof(real) * DIM * DIM * DIM), "malloc grid");
    cuda_error(cudaMalloc(&next_grid_d, sizeof(real) * DIM * DIM * DIM), "malloc next grid");
    cuda_error(cudaMemset(grid_d, 0, sizeof(real) * DIM * DIM * DIM), "memset grid");
    ConstantSum point_charge(0.3, DIM / 2, DIM / 2, DIM / 2);
    ConstantSum* point_charge_d;
    cuda_error(cudaMalloc(&point_charge_d, sizeof(ConstantSum)), "malloc constant sum");
    cuda_error(cudaMemcpy(point_charge_d, &point_charge, sizeof(ConstantSum), cudaMemcpyHostToDevice), "memset constant sum");
    real max_err;
    do
    {
        compute_next_grid_kernel<<< DIM, dim3(16, 16) >>>(grid_d, next_grid_d, DIM, point_charge_d, 1);
        cuda_error(cudaGetLastError(), "kernel launch");
        max_err = *thrust::max_element(grid_d, grid_d + DIM * DIM * DIM);
        std::swap(grid_d, next_grid_d);
    }
    while(max_err > 0.1);
    real* frame = new real[DIM * DIM];
    cuda_error(cudaMemcpy(frame, grid_d + DIM * DIM * (DIM / 2), DIM * DIM * sizeof(real), cudaMemcpyDeviceToHost), "memcpy frame");
    cuda_error(cudaFree(grid_d), "free grid");
    cuda_error(cudaFree(next_grid_d), "free next grid");
    cuda_error(cudaFree(point_charge_d), "free point charge");
    for(int i = 0; i < DIM; i++)
    {
        for(int j = 0; j < DIM; j++)
        {
            std::cout << frame[DIM * i + j] << "t";
        }
        std::cout << "n";
    }
    delete[] frame;
    return 0;
}

通常,thrust 使用传递的迭代器的类型来确定算法后端是否会在主机或设备上运行(最新版本中也有无标记和显式执行策略选择,但这是不同的讨论(。

在您的情况下,由于grid_d主机指针(无论其是主机还是设备地址无关紧要(,因此 thrust 将尝试在主机上运行算法。这是段错误的源,您正在尝试访问主机上的设备地址。

要完成这项工作,您需要将指针投射到一个 thrust::dev_ptr ,如下所示:

thrust::device_ptr<real> grid_start = thrust::device_pointer_cast(grid_d);
thrust::device_ptr<real> grid_end= thrust::device_pointer_cast(grid_d + DIM * DIM * DIM);
auto max_it = thrust::max_element(grid_start, grid_end);
max_error = *max_it;

[警告,在浏览器中编写,从未编译或测试,使用风险自负]

通过传递thrust::dev_ptr,将进行正确的标记选择,并且闭包将在设备上运行。

另一种不强制转换的解决方案是指定执行策略device

thrust::max_element(thrust::device, grid_d, grid_d + DIM * DIM * DIM);

不,显式执行策略控制仅在 Thrust 1.7 及更高版本上受支持。