Pytorch torch.cholesky忽略异常

Pytorch torch.cholesky ignoring exception

本文关键字:异常 cholesky torch Pytorch      更新时间:2023-10-16

对于我批次上的某些矩阵,由于矩阵是单数的,我有一个例外。

L = th.cholesky(Xt.bmm(X))

cholesky_cpu:对于批次 51100:U(22,22( 为零,单数 U

由于它们对于我的用例来说很少,我想忽略异常并进一步处理它们。我将结果计算设置为nan是否可能?

实际上,如果我catch异常并使用continue它仍然没有完成批处理其余部分的计算。

同样的情况也发生在 Pytorch libtorch 的C++中。

在执行cholesky分解时,PyTorch依赖于LAPACK作为CPU张量,MAGMA用于CUDA张量。在用于调用 LAPACK 的 PyTorch 代码中,批处理只是迭代,分别在每个矩阵上调用 LAPACK 的zpotrs_函数。在用于调用MAGMA的PyTorch代码中,整个批处理使用MAGMA的magma_dpotrs_batched进行处理,这可能比单独迭代每个矩阵更快。

AFAIK 没有办法向 MAGMA 或 LAPACK 指示不提出例外(尽管公平地说,我不是这些软件包的专家(。由于MAGMA可能以某种方式利用批处理,因此我们可能不希望默认为迭代方法,因为如果不执行批处理的cholesky,我们可能会失去性能。

一种可能的解决方案是首先尝试执行批量 cholesky 分解,如果失败,那么我们可以对批处理中的每个元素执行 cholesky 分解,将失败的条目设置为 NaN。

def cholesky_no_except(x, upper=False, force_iterative=False):
success = False
if not force_iterative:
try:
results = torch.cholesky(x, upper=upper)
success = True
except RuntimeError:
pass
if not success:
# fall back to operating on each element separately
results_list = []
x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
for batch_idx in range(x_batched.shape[0]):
try:
result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
except RuntimeError:
# may want to only accept certain RuntimeErrors add a check here if that's the case
# on failure create a "nan" matrix
result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
results_list.append(result)
results = torch.cat(results_list, dim=0).reshape(*x.shape)
return results

如果您希望在cholesky分解过程中出现异常,则可能需要使用force_iterative=True跳过尝试使用批处理版本的初始调用,因为在这种情况下,此函数可能只是在第一次尝试时浪费时间。

我不知道这与发布的其他解决方案相比速度如何,但它可能会更快。

首先使用torch.det来确定批次中是否存在任何奇异矩阵。然后屏蔽掉这些矩阵。

output = Xt.bmm(X)
dets = torch.det(output)
# if output is of shape (bs, x, y), dets will be of shape (bs)
bad_idxs = dets==0 #might want an allclose here
output[bad_idxs] = 1. # fill singular matrices with 1s
L = torch.cholesky(output)

在你可能需要处理你用 1 填充的奇异矩阵之后,你有它们的索引值,所以很容易抓取或排除它们。

根据论坛Pytorch Discuss无法捕获异常。

不幸的是,解决方案是实现我自己的简单批处理cholesky(th.cholesky(..., upper=False)(,然后使用th.isnan处理Nan值。

import torch as th
# nograd cholesky
def cholesky(A):
L = th.zeros_like(A)
for i in range(A.shape[-1]):
for j in range(i+1):
s = 0.0
for k in range(j):
s = s + L[...,i,k] * L[...,j,k]
L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else 
(1.0 / L[...,j,j] * (A[...,i,j] - s))
return L