关于Tensorflow和PyTorch中的自定义操作

About custom operations in Tensorflow and PyTorch

本文关键字:自定义 操作 PyTorch Tensorflow 关于      更新时间:2023-10-16

我必须实现一个能量函数,称为刚性能量,如本文的等式7所示
能量函数将两个三维对象网格作为输入,并返回它们之间的能量。第一个网格是源网格,第二个网格是变形版本的源网格。在粗略的psuedo代码中,计算如下:

在源网格中的所有顶点上迭代。

  1. 对于每个顶点,计算其与相邻顶点的协方差矩阵
  2. 对计算出的协方差矩阵进行SVD,求出顶点的旋转矩阵
  3. 使用计算的旋转矩阵、原始网格中的点坐标和变形网格中的相应坐标来计算顶点的能量偏差

因此,这个能量函数要求我在网格中的每个点上迭代,并且网格可能有超过2k个这样的点。在Tensorflow中,有两种方法可以做到这一点。我可以有两个形状为(N,3)的张量,一个表示源点,另一个表示变形网格。

  1. 纯粹使用张量流张量。也就是说,使用tf.gather迭代上述张量的元素,并仅使用现有的TF运算对每个点执行计算。这种方法会非常缓慢。我以前尝试过定义迭代超过1000个点的损失函数,而图的构建本身需要太多时间来实现
  2. 添加一个新的TF OP,如TF文档中所述。这包括在CPP(和Cuda,用于GPU支持)中编写函数,并向TF注册新的OP

第一种方法写起来很容易,但速度慢得不切实际。第二种方法写起来很痛苦。

我已经使用TF 3年了,以前从未使用过PyTorch,但现在我正在考虑改用它,如果它能为这种情况提供更好的替代方案的话。

PyTorch有没有一种方法可以实现这样的损失函数?在GPU上的执行速度一样快。也就是说,用一种蟒蛇般的方式编写我自己在GPU上运行的损失函数,而我没有任何C或Cuda代码?

据我所知,您实际上是在问这个操作是否可以向量化。答案是否定的,至少不是完全否定的,因为PyTorch中的svd实现没有矢量化。

如果您展示了tensorflow实现,它将有助于理解您的起点。我不知道你说的找到顶点的旋转矩阵是什么意思,但我想这可以矢量化。这意味着svd是唯一的非矢量化操作,您可以只编写一个自定义OP,也就是矢量化svd,这可能很容易,因为它相当于在C++中的循环中调用一些库例程。

我看到的两个可能的问题来源是

  1. 如果方程7中N(i)的邻域可以具有显著不同的大小(这意味着协方差矩阵具有不同的大小,并且矢量化将需要一些肮脏的技巧)
  2. 处理网格和邻域的一般问题可能很困难。这是不规则网格的固有特性,但PyTorch支持稀疏矩阵和专用包torch_geometry,这至少有帮助