使用 cblas_dgemm 计算伪逆的问题

Issue using cblas_dgemm to compute pseudoinverse

本文关键字:问题 计算 cblas dgemm 使用      更新时间:2023-10-16

我正在尝试使用英特尔 MKL 计算存储在LAPACK_ROW_MAJOR布局中的矩阵的伪逆。

A_5x4 =
1     2     3     4
5     6     7     8
9    10    11    12
13    14    15    16
17    18    19    20

我正在使用gesvd函数来计算SVD的紧凑形式:

info = LAPACKE_dgesvd(LAPACK_ROW_MAJOR, 'S', 'S', m, n, A, lda, s, u, ldu, vt, ldvt, superb);

其中m=5n=4lda=4ldu=5ldvt=4。我可以成功地使用MKL函数来获取矩阵的SVD,A = U*S*VT

u_5x4 = 
0.0965          0.7686          0.6323          0.0034
0.2455          0.4896         -0.6208          0.0412
0.3945          0.2107         -0.3285         -0.4681
0.5435         -0.0683         -0.0097          0.7989
0.6924         -0.3472          0.3267         -0.3754
s_4x1 = 
53.520222
2.363426
0.000000
0.000000
vt_4x4 = 
0.4430          0.4799          0.5167          0.5536
-0.7097         -0.2640          0.1816          0.6273
0.0912         -0.5242          0.7747         -0.3417
0.5401         -0.6521         -0.3160          0.4280

因为s只有两个非零元素,所以我需要考虑u的前两列,v的两列(不是vt(以及s元素的倒数

v_4x2_needed_for_pinv = 
0.4430    0.4799
-0.7097   -0.2640
0.0912   -0.5242
0.5401   -0.6521
u_2x5_needed_for_pinv = 
0.0965   0.2455   0.3945   0.5435   0.6924
0.7686   0.4896   0.2107  -0.0683  -0.3472

我可以毫无问题地用for-loop执行矩阵乘法并计算 A 的伪逆。但是,我对使用dscalcblas_dgemm非常感兴趣,主要是因为要计算逆矩阵的实际矩阵非常大。

我能够成功地使用dscal计算出来,并将 V 的前两列乘以 S 的倒数:

MKL_INT  k = ((m) < (n) ? (m) : (n));
// Computing VT = vt*(s^-1)
MKL_INT incx = 1;
MKL_INT r = 0;
for (int i = 0; i < k; i++)
{
double ss;
if (s[i] > 1.0e-9)
{
ss = 1.0 / s[i];
r++;
}
else
ss = s[i];
dscal(&n, &ss, &vt[i*n], &incx); // this replaces vt with new values.
}

我的问题是使用u_2x5_needed_for_pinv执行矩阵乘法v_4x2_needed_for_pinv,这是LAPACKE_dgesvd计算的uvt数组的子集。有人可以帮我弄清楚如何使用cblas_dgemm吗?我将不胜感激。

我尝试了以下内容,函数的输入对我来说很有意义,但它不起作用

// inv(A) = VT^T * U^T = V * U^T
double* inva = (double*)malloc(n*m * sizeof(double));
double alpha = 1.0, beta = 0.0;
MKL_INT ld_inva = n;
cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, n, m, r, alpha, vt, n, u, m, beta, inva, ld_inva);

其中r=2因为s只有两个非零元素(53.5202222.363426(。

由于最后三个奇异值为零,我们可以说 SVD 产生:

  • u(5,2)ldu=4
  • vt(2,4)ldvt=4
  • invA(5,4)

逆的计算公式为 invA = vt^T * invS * u^T 并且跟随您的循环可以转换为 invA = (invS * vt(^T * u^T

MKL_INT ma = mu = 5;
MKL_INT na = nvt = 4;
MKL_INT nu = mvt = ms = 2;
MKL_INT lda = ldu = ldvt = 4;
// vt = (invS * vt)
for(MKL_INT i=0; i<ms; i++){
cblas_dscal (nvt, s[i], vt+(i*ldvt), 1);
}
// invA = vt^T * u^T
cblas_dgemm (CblasRowMajor, CblasTrans, CblasTrans, ma, na, nu, 1.0, vt, ldvt, u, ldu, 0.0, invA, lda);