C++矩阵乘法类型检测

C++ Matrix multiplication type detection

本文关键字:类型 检测 C++      更新时间:2023-10-16

在我的C++代码中,我有一个矩阵类,以及一些编写的运算符来乘以它们。我的类是模板化的,这意味着我可以有 int、float、double ...矩阵。

我猜我的运算符重载是经典

    template <typename T, typename U>
    Matrix<T>& operator*(const Matrix<T>& a, const Matrix<U>& b)
    {
    assert(a.rows() == b.cols() && "You have to multiply a MxN matrix with a NxP one to get a MxP matrixn");
    Matrix<T> *c = new Matrix<T>(a.rows(), b.cols());
    for (int ci=0 ; ci<c->rows() ; ++ci)
    {
      for (int cj=0 ; cj<c->cols() ; ++cj)
      {
        c->at(ci,cj)=0;
        for (int k=0 ; k<a.cols() ; ++k)
        {
          c->at(ci,cj) += (T)(a.at(ci,k)*b.at(k,cj));
        }
      }
    }
    return *c;
  }

在此代码中,我返回与第一个参数类型相同的矩阵,即 Matrix<int> * Matrix<float> = Matrix<int> .我的问题是我如何检测我给出的两个中最精确的类型,以免损失太多精度,即具有Matrix<int> * Matrix<float> = Matrix<float>?有没有聪明的人可以做到这一点?

你想要的只是将T乘以U时发生的类型。这可以通过以下方式给出:

template <class T, class U>
using product_type = decltype(std::declval<T>() * std::declval<U>());

您可以将其用作额外的默认模板参数:

template <typename T, typename U, typename R = product_type<T, U>>
Matrix<R> operator*(const Matrix<T>& a, const Matrix<U>& b) {
    ...
}
<小时 />

在 C++03 中,您可以通过使用许多小的帮助程序类型执行一系列巨大的重载来完成同样的事情(这就是 Boost 的做法):

template <int I> struct arith;
template <int I, typename T> struct arith_helper {
    typedef T type;
    typedef char (&result_type)[I];
};
template <> struct arith<1> : arith_helper<1, bool> { };
template <> struct arith<2> : arith_helper<2, bool> { };
template <> struct arith<3> : arith_helper<3, signed char> { };
template <> struct arith<4> : arith_helper<4, short> { };
// ... lots more

然后我们可以写:

template <class T, class U>
class common_type {
private:
    static arith<1>::result_type select(arith<1>::type );
    static arith<2>::result_type select(arith<2>::type );
    static arith<3>::result_type select(arith<3>::type );
    // ...
    static bool cond();
public:
    typedef typename arith<sizeof(select(cond() ? T() : U() ))>::type type;
};

假设你写出了所有的整数类型,那么你可以使用typename common_type<T, U>::type在我使用之前product_type.

如果这不是C++11有多酷的证明,我不知道是什么。

<小时 />

请注意,operator*不应返回引用。您正在执行的操作会泄漏内存。