CUDA内核使用模板启动宏

CUDA kernel launch macro with templates

本文关键字:启动 内核 CUDA      更新时间:2023-10-16

我制作了一个宏来简化CUDA内核调用:

#define LAUNCH LAUNCH_ASYNC
#define LAUNCH_ASYNC(kernel_name, gridsize, blocksize, ...) 
    LOG("Async kernel launch: " #kernel_name);              
    kernel_name <<< (gridsize), (blocksize) >>> (__VA_ARGS__);
#define LAUNCH_SYNC(kernel_name, gridsize, blocksize, ...)     
    LOG("Sync kernel launch: " #kernel_name);                  
    kernel_name <<< (gridsize), (blocksize) >>> (__VA_ARGS__); 
    cudaDeviceSynchronize();                                   
    // error check, etc...

用法:

LAUNCH(my_kernel, 32, 32, param1, param2)
LAUNCH(my_kernel<int>, 32, 32, param1, param2)

这很好;有了第一个定义,我可以启用synroous调用和错误检查以进行调试。

然而,它不适用于以下多个模板参数:

LAUNCH(my_kernel<int,float>, 32, 32, param1, param3)

我在调用宏的行中得到的错误消息:

error : expected a ">"

是否可以使此宏与多个模板参数一起工作?

问题是预处理器对尖括号嵌套一无所知,所以它将它们之间的逗号解释为宏参数分隔符。

如果内核启动语法支持内核名称周围的括号(我现在不能检查,在CUDA机器上也不能),您可以这样做:

LAUNCH((my_kernel<int, float>), 32, 32, param1, param3)

您可以尝试的其他方法(基于您发布的宏)是将内核块大小和网格大小参数封装在自己的宏中:

#define KERNEL_ARGS2(grid, block) <<< grid, block >>>
#define KERNEL_ARGS3(grid, block, sh_mem) <<< grid, block, sh_mem >>>
#define KERNEL_ARGS4(grid, block, sh_mem, stream) <<< grid, block, sh_mem, stream >>>

现在你应该可以像这样使用宏了:

#define CUDA_LAUNCH(kernel_name, gridsize, blocksize, ...) 
kernel_name KERNEL_ARGS2(gridsize, blocksize)(__VA_ARGS__);

你可以像这样使用它:

CUDA_LAUNCH(my_kernel, grid_size, block_size, float* input, float* output, int size);

这将启动名为"my_kernal"的内核,该内核具有给定的网格和块大小以及输入参数。

考虑这个也会抛出错误的解决方案

inline void echoError(cudaError_t e, const char *strs) {
    char a[255];
    if (e != cudaSuccess) {
        strncpy(a, strs, 255);
        fprintf(stderr, "Failed to %s,errorCode %s",
                a, cudaGetErrorString(e));
        exit(EXIT_FAILURE);
    }
}

#define CUDA_KERNEL_DYN(kernel, bpg, tpb, shd, ...){                     
    kernel<<<bpg,tpb,shd>>>( __VA_ARGS__ );                              
    cudaError_t err = cudaGetLastError();                                
    echoError(err, #kernel);                                              
}