在 pytorch c++ API 中添加填充

Add padding in pytorch c++ API

本文关键字:添加 填充 API pytorch c++      更新时间:2023-10-16

>我有一个维度(1,3, 375, 1242)的张量。我想通过添加填充来重塑它以(1, 3, 384, 1248)。如何在 Pytorch c++ API 中做到这一点。提前谢谢你。

target = torch.zeros(1, 3, 384, 1248)
source = torch.ones(1, 3, 375, 1242)
target[: , : , :375, :1242] = source
您可以使用

torch::constant_pad_nd

torch::Tensor source = torch::ones(torch::IntList{1, 3, 375, 1242});
// add 6 zeros to the last dimension and 9 zeros to the third dimension
torch::Tensor target = torch::constant_pad_nd(target, IntList{0, 6, 0, 9}, 0);