重载 C++ 中的乘法运算符

Overloading the multiplication operator in c++

本文关键字:运算符 C++ 重载      更新时间:2023-10-16

我已经为 LAPACK 编写了一个C++接口,但我遇到了一些内存问题,这些问题让我重新考虑了一些运算符重载。

现在,我已经在类定义之外重载了运算符*(但作为矩阵类的朋友(,它接受两个矩阵对象,为第三个对象分配适当的维度,使用 D(GE/SY(MM 计算乘积(存储到新分配矩阵的内部存储中(,然后返回指向该新矩阵的指针。即

class Matrix {
...
friend Matrix* operator*(const Matrix&, const Matrix&);
...
}
Matrix* operator*(const Matrix& m1, const Matrix& m2) {
  Matrix *prod = new Matrix(m1.rows_, m2.cols_);
  if(m1.cols_!=m2.rows_) {
    throw 3008;
  } else {
    double alpha = 1.0;
    double beta = 0.0;
    if(m1.symm_=='G' && m2.symm_=='G'){
      dgemm_(&m1.trans_,&m2.trans_,&m1.rows_,&m2.cols_,&m1.cols_,&alpha,m1.data_,
             &m1.rows_,m2.data_,&m1.cols_,&beta,prod->data_,&m2.cols_);
    } else if(m1.symm_=='S'){
      char SIDE = 'L';
      char UPLO = 'L';
      dsymm_(&SIDE,&UPLO,&m1.rows_,&m2.cols_,&alpha,m1.data_,&m1.rows_,m2.data_,
             &m2.cols_,&beta,prod->data_,&m2.cols_);
    } else if(m2.symm_=='S'){
      char SIDE = 'R';
      char UPLO = 'L';
      dsymm_(&SIDE,&UPLO,&m2.rows_,&m1.cols_,&alpha,m2.data_,&m2.rows_,m1.data_,
             &m1.cols_,&beta,prod->data_,&m1.cols_);
    };
  }
  return prod;
};

然后我利用

Matrix *A, *B, *C;
// def of A and B
C = (*A)*(*B);

这工作得很好。我遇到的问题是每次执行此操作时都必须分配一个新矩阵。我希望能够做的是分配一次C矩阵,然后将AB的乘积放入C的内部存储中(C->data_(。从我能够在运算符重载上找到的内容来看,我找不到一种很好的方法来做到这一点。我知道我可以使用成员函数来执行此操作,(即 C->mult(A,B) (,但如果可能的话,我想避免这种情况(我正在编写代码是为了便于非 CSE 类型的开发(。任何想法将不胜感激。

class Matrix
{
    struct Product
    {
        const Matrix* a; 
        const Matrix* b;
    };
    Matrix& operator = (const Product& p)
    {
        // if this matrix dims differ from required by product of p.a and p.b
        // reallocate it first and set dims
        // {     
            // rows = ....; cols = ....;
            // delete [] data; 
            // data = new [rows*cols];
        // }

        // then calculate product
        // data[0] = ...;
        // ...
        return *this;
    }
    Product operator * (const Matrix& op) const
    {
        Product p;
        p.a = this;
        p.b = &op;
        return p;
    }
    int rows,cols;
    double* data;
    /// your Matrix stuff
    // ...
};
void test()
{
    Matrix a(4,2),b(2,4),c;
    c = a * b; // (note a*b returns Product without calculating its result
               // result is calculated inside = operator
    c = a * b; // note that this time c is initialized to correct size so no
               // additional reallocations will occur
}