Multiplying a matrix with its transpose using cuBlas

Keywords: transpose, cuBlas    Last updated: 2023-10-16

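I am trying to multiply a matrix by its transpose, but I can't get the sgemm call right. Sgemm takes a lot of parameters, and some of them, such as lda and ldb, confuse me. If I call the function below with a square matrix it works; with any other shape it does not.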

/*param inMatrix: contains the matrix data in column-major order, e.g. [1 2 3 1 2 3]
  param rowNum: number of rows in the matrix, e.g. if the matrix is
                 |1  1|
                 |2  2|
                 |3  3| then rowNum should be 3*/
void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum)
{
    cublasHandle_t handle;
    cublasCreate(&handle);
    int colNum = (int)inMatrix.size() / rowNum;
    thrust::device_vector<float> d_InMatrix(inMatrix);
    thrust::device_vector<float> d_outputMatrix(rowNum*rowNum);
    float alpha = 1.0f;
    float beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha,
        thrust::raw_pointer_cast(d_InMatrix.data()), colNum, thrust::raw_pointer_cast(d_InMatrix.data()), colNum, &beta,
        thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum);
    thrust::host_vector<float> result = d_outputMatrix;
    for (auto elem : result)
        std::cout << elem << ",";
    std::cout << std::endl;
    cublasDestroy(handle);
}
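To see why only the square case works: in column-major storage, element (i, j) lives at data[i + j * ld], so a rowNum x colNum matrix has leading dimension rowNum. The minimal sketch below (plain C++, no GPU; `at` and `view` are hypothetical helpers, not cuBLAS functions) reads the 3 x 2 example from the comment above with both choices of ld.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of cuBLAS): element (i, j) of a column-major
// matrix whose consecutive columns are `ld` elements apart in `data`.
float at(const std::vector<float>& data, int ld, int i, int j) {
    assert(i >= 0 && j >= 0 && i + j * ld < (int)data.size());
    return data[i + j * ld];
}

// Hypothetical helper: reads a rows x cols matrix out of `data` using leading
// dimension `ld` and returns it row by row, so it is easy to compare or print.
std::vector<std::vector<float>> view(const std::vector<float>& data, int ld,
                                     int rows, int cols) {
    std::vector<std::vector<float>> m(rows, std::vector<float>(cols));
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            m[i][j] = at(data, ld, i, j);
    return m;
}
```

For the buffer {1, 2, 3, 1, 2, 3}, `view(data, 3, 3, 2)` (ld = rowNum) gives the rows {1, 1}, {2, 2}, {3, 3}, which is the intended matrix. `view(data, 2, 3, 2)` (ld = colNum, as passed in the call above) gives {1, 3}, {2, 1}, {3, 2}, which is not. For a square matrix rowNum equals colNum, so the mistake stays hidden.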

What am I missing? What is the correct sgemm call for computing a matrix times its transpose?

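The following setup works for me. Because the matrix is stored in column-major order, its leading dimension is its number of rows, so lda and ldb must be rowNum, not colNum. For a square matrix the two values are equal, which is why the original call only worked in that case. Please let me know if I've missed anything; I hope this is useful to someone.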
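As a cross-check on the argument order, here is a CPU reference for what cublasSgemm computes with transa = CUBLAS_OP_N and transb = CUBLAS_OP_T on column-major data: C = alpha * A * B^T + beta * C, with A m x k, B n x k and C m x n. It is a hedged sketch for testing small inputs, not cuBLAS itself; `gemm_nt_colmajor` and `times_transpose` are hypothetical names. `times_transpose` passes the same m, n, k, lda, ldb and ldc as the corrected call.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU reference (not a cuBLAS function), column-major throughout:
// C = alpha * A * B^T + beta * C, where A is m x k, B is n x k, C is m x n.
void gemm_nt_colmajor(int m, int n, int k, float alpha,
                      const float* A, int lda, const float* B, int ldb,
                      float beta, float* C, int ldc) {
    // Each column of a column-major matrix must fit within its leading dimension.
    assert(lda >= m && ldb >= n && ldc >= m);
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p)
                sum += A[i + p * lda] * B[j + p * ldb];  // A(i,p) * B^T(p,j)
            C[i + j * ldc] = alpha * sum + beta * C[i + j * ldc];
        }
}

// A * A^T for a column-major rowNum x colNum matrix, with the answer's arguments:
// m = n = rowNum, k = colNum, lda = ldb = ldc = rowNum.
std::vector<float> times_transpose(const std::vector<float>& a, int rowNum) {
    const int colNum = (int)a.size() / rowNum;
    std::vector<float> c(rowNum * rowNum, 0.0f);
    gemm_nt_colmajor(rowNum, rowNum, colNum, 1.0f, a.data(), rowNum,
                     a.data(), rowNum, 0.0f, c.data(), rowNum);
    return c;
}
```

For the 3 x 2 example, `times_transpose({1, 2, 3, 1, 2, 3}, 3)` returns {2, 4, 6, 4, 8, 12, 6, 12, 18}. For the 2 x 3 matrix |1 2 3|, |4 5 6|, stored as {1, 4, 2, 5, 3, 6}, it returns {14, 32, 32, 77}. The assert encodes the storage rule behind the fix: a column of m values cannot fit in a leading dimension smaller than m.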

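#include <cublas_v2.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <iostream>
#include <vector>

// Computes inMatrix * inMatrix^T, where inMatrix is rowNum x colNum in column-major order.
void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum)
{
    cublasHandle_t handle;
    cublasCreate(&handle);
    int colNum = (int)inMatrix.size() / rowNum;
    thrust::device_vector<float> d_InMatrix(inMatrix);            // copy A to the GPU
    thrust::device_vector<float> d_outputMatrix(rowNum*rowNum);   // C is rowNum x rowNum
    float alpha = 1.0f;
    float beta = 0.0f;
    // C = A * A^T: m = n = rowNum, k = colNum. A is column-major, so its leading
    // dimension is its number of rows: lda = ldb = rowNum (and ldc = rowNum for C).
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha,
        thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, &beta,
        thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum);
    thrust::host_vector<float> result = d_outputMatrix;  // copy C back (column-major; C is symmetric)
    for (auto elem : result)
        std::cout << elem << ",";
    std::cout << std::endl;
    cublasDestroy(handle);
}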
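This function expects column-major input, as described in the question's comment. If the data is held row-major instead (the usual C++ layout, e.g. [1 1 2 2 3 3] for the same 3 x 2 matrix), one option is to reorder it on the host first. A minimal sketch, with `rowMajorToColMajor` as a hypothetical helper name:

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: reorders a rows x cols matrix stored row-major
// (element (i, j) at i * cols + j) into column-major order
// (element (i, j) at i + j * rows), the layout cublasSgemm expects.
std::vector<float> rowMajorToColMajor(const std::vector<float>& rowMajor, int rows) {
    assert(rows > 0 && rowMajor.size() % rows == 0);
    const int cols = (int)rowMajor.size() / rows;
    std::vector<float> colMajor(rowMajor.size());
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            colMajor[i + j * rows] = rowMajor[i * cols + j];
    return colMajor;
}
```

Usage: `std::vector<float> m = rowMajorToColMajor({1, 1, 2, 2, 3, 3}, 3);` yields {1, 2, 3, 1, 2, 3}, which can be passed to `matrixtTransposeMult(m, 3)`. The copy can likely be avoided: a row-major rowNum x colNum buffer read as column-major with leading dimension colNum is A^T, so calling cublasSgemm with CUBLAS_OP_T, CUBLAS_OP_N and lda = ldb = colNum (m, n, k and ldc unchanged) should also produce A * A^T. Either way, C is symmetric, so its storage order does not matter.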