CUDA矩阵类的运算符()重载

Overload of operator() for a CUDA matrix class

本文关键字:重载 运算符 CUDA      更新时间:2023-10-16

我有CPU和GPU(CUDA)矩阵类,我想重载operator(),这样我就可以读取或写入矩阵的各个元素。

对于CPU矩阵类,我可以通过做到这一点

OutType & operator()(const int i) { return data_[i]; }

(读取)和

OutType operator()(const int i) const { return data_[i]; }

(书写)。对于GPU矩阵类,我能够过载operator()以供读取

__host__ OutType operator()(const int i) const { OutType d; CudaSafeCall(cudaMemcpy(&d,data_+i,sizeof(OutType),cudaMemcpyDeviceToHost)); return d; }

但我无法在写作上做到这一点。有人能提供任何线索来解决这个问题吗?

CPU的写入情况返回data_[i]的引用,因此分配作业由内置的C++operator=执行。我不知道我怎么能利用CUDA的同样机制。

谢谢。

您可以创建一个单独的类,该类具有重载的赋值运算符和类型转换运算符,并模拟引用行为:

class DeviceReferenceWrapper
{
public:
    explicit DeviceReferenceWrapper(void* ptr) : ptr_(ptr) {}
    DeviceReferenceWrapper& operator =(int val)
    {
        cudaMemcpy(ptr_, &val, sizeof(int), cudaMemcpyHostToDevice);
        return *this;
    }
    operator int() const
    {
        int val;
        cudaMemcpy(&val, ptr_, sizeof(int), cudaMemcpyDeviceToHost);
        return val;
    }
private:
    void* ptr_;
};

并将其用于矩阵类

class Matrix
{
    DeviceReferenceWrapper operator ()(int i)
    {
        return DeviceReferenceWrapper(data + i);
    }
};