PyTorch C++扩展:访问Half-Tensors的数据

PyTorch C++ Extensions: Accessing data for Half Tensors

本文关键字:Half-Tensors 数据 访问 C++ 扩展 PyTorch      更新时间:2023-10-16

我正在尝试使用C++Tensor API为PyTorch编写C++/CUDA扩展,我希望我的代码能同时使用float32和float16(半精度(。我不知道如何访问来自Python的半张量的数据指针。

以下是我对浮动张量的处理方法:

// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();

以下是我对半张量的尝试:

// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();
// PyTorch float16 type
// error: no instance of function template "at::Tensor::data" 
A.data<torch::ScalarType::Half>();
// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());

我试着查看C++api源代码,但找不到其他类似float16类型的代码。

系统信息:Python 3.6.2PyTorch 1.0.1

正确的类型是at::Half