有效使用enable_if C++模板以避免类专业化

Effective use of enable_if with C++ templates to avoid class specialization

本文关键字:专业化 C++ enable if 有效      更新时间:2023-10-16

我在编译代码时遇到问题。 CLANG,G++和ICPC都给出不同的错误消息,

在进入问题本身之前,先了解一下背景知识:

我现在正在研究用于处理矩阵的模板类层次结构。 数据类型(浮点型或双精度型)和"实施策略"都有模板参数,目前包括带循环的常规C++代码和英特尔 MKL 版本。 以下是删节摘要(请忽略此处缺少前瞻性参考等 - 这与我的问题无关):

// Matrix.h
template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP>;
template <typename Matrix_Type>
class Matrix_Base{
    /* ... */
    // Matrix / Scalar addition
    template <typename T>
    Matrix_Base& operator+=(const T value_) { 
      return Implementation<IP>::Plus_Equal(
          static_cast<Matrix_Type&>(*this), value_);
    /* More operators and rest of code... */
    };
struct CPP;
struct MKL;
template <typename IP>
struct Implementation{
/* This struct contains static methods that do the actual operations */

我现在遇到的麻烦与实现类的实现有关(没有双关语)。 我知道我可以使用实现模板类的专用化来专门化template <> struct Implementation<MKL>{/* ... */};但是,这将导致大量代码重复,因为有许多运算符(例如矩阵标量加法、减法...... )泛型和专用版本使用相同的代码。

因此,相反,我认为我可以摆脱模板专用化,只使用enable_if为那些在使用 MKL(或 CUDA 等)时具有不同实现的运算符提供不同的实现。

事实证明,这比我最初预期的更具挑战性。 第一个 - 对于operator += (T value_)工作正常。 我添加了检查只是为了确保参数是合理的(如果它是我麻烦的根源,这可以消除,我对此表示怀疑)。

template <class Matrix_Type, typename Type, typename enable_if< 
    std::is_arithmetic<Type>::value  >::type* dummy = nullptr>
static Matrix_Type& Plus_Equal(Matrix_Type& matrix_, Type value_){
    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    //y := A + b
    #pragma parallel 
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] += value_; 
    return matrix_;
}

但是,我很难弄清楚如何处理operator *=(T value_)。 这是因为floatdouble对 MKL 有不同的实现,但在一般情况下并非如此。

这是声明。 请注意,第三个参数是一个虚拟参数,是我试图强制函数重载的尝试,因为我无法使用部分模板函数专用化:

template <class Matrix_Type, typename U, typename Type = 
    typename internal::Type_Traits< Matrix_Type>::type, typename  enable_if<
    std::is_arithmetic<Type>::value >::type* dummy = nullptr>
static Matrix_Type& Times_Equal(Matrix_Type& matrix_, U value_, Type dummy_ = 0.0);

一般情况的定义。 :

template<class IP>
template <class Matrix_Type, typename U, typename Type,  typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<IP>::Times_Equal(Matrix_Type& matrix_, U value_, Type){
    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    //y := A - b
    #pragma parallel
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] *= value_;
    return matrix_;
}

当我尝试实现 MKL 的专业化时,麻烦就开始了:

template<>
template <class Matrix_Type, typename U, typename Type, typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<implementation::MKL>::Times_Equal(
    Matrix_Type& matrix_, 
    U value_,
    typename enable_if<std::is_same<Type,float>::value,Type>::type)
{
    float value = value_;
    MKL_INT total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    MKL_INT const_one = 1;
    //y := a * b
    sscal(&total_elements, &value, matrix_.Data, &const_one);
    return matrix_;
}

这给了我一个叮当的错误:

_错误:"Times_Equal"的外联定义与"实现"中的任何声明都不匹配_

在 G++ 中(略有缩短)

_error: 'Matrix_Type& Implementation::Times_Equal(...)' 的模板 ID 'Times_Equal<>' 与任何模板声明都不匹配。

如果我将第三个参数更改为 Type,而不是enable_if,代码编译得很好。 但是当我这样做时,我看不到如何为 float 和 double 提供单独的实现。

任何帮助将不胜感激。

我认为使用

std::enable_if 实现这将非常乏味,因为一般情况必须使用enable_if来实现,如果它不适合其中一个专业,则打开它。

具体到你的代码,我认为编译器无法推断出你的MKL专业化中的Type,因为它隐藏在std::enable_if中,因此这种专业化永远不会被调用。

与其使用enable_if,不如执行以下操作:

#include<iostream>
struct CPP {};
struct MKL {};
namespace Implementation
{
   //
   // general Plus_Equal
   //
   template<class Type, class IP>
   struct Plus_Equal
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, Type value_)
      {
         std::cout << " Matrix Plus Equal General Implementation " << std::endl;
         // ... do general Plus_Equal ...
         return matrix_;
      }
   };
   //
   // specialized Plus_Equal for MKL with Type double
   //
   template<>
   struct Plus_Equal<double,MKL>
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, double value_)
      {
         std::cout << " Matrix Plus Equal MKL double Implementation " << std::endl;
         // ... do MKL/double specialized Plus_Equal ...
         return matrix_;
      }
   };
} // namespace Implementation
template <typename Type, typename IP, typename Matrix_Type>
class Matrix_Base
{  
   public:
   // ... matrix base implementation ...
   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      return Implementation::Plus_Equal<Type,IP>::apply(static_cast<Matrix_Type&>(*this), value_);
   }
   // ...More operators and rest of code...
};
template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP, Matrix<Type,IP> >
{
   // ... Matrix implementation ...
};
int main()
{
   Matrix<float ,MKL> f_mkl_mat;
   Matrix<double,MKL> d_mkl_mat;
   f_mkl_mat+=2.0; // will use general plus equal
   d_mkl_mat+=2.0; // will use specialized MKL/double version
   return 0;
}

在这里,我使用了类专用化而不是 std::enable_if。我发现您与示例中的IPTypeMatrix_Type类型非常不一致,所以我希望我在这里正确使用它们。

作为关于std::enable_if的评论的旁白。我会使用表格

template<... , typename std::enable_if< some bool >::type* = nullptr> void func(...);

template<... , typename = std::enable_if< some bool >::type> void func(...);

因为它使您能够执行一些其他表单无法完成的功能重载。

希望你能用:)

编辑20/12-13:重新阅读我的帖子后,我发现我应该明确地做CRTP(奇怪的重复模板模式),我在上面的代码中添加了它。我将TypeIP都传递给Matrix_Base.如果你觉得这很乏味,可以提供一个矩阵特征类,Matrix_Base可以从中取出它们。

template<class A>
struct Matrix_Traits;
// Specialization for Matrix class
template<class Type, class IP>
struct Matrix_Traits<Matrix<Type,IP> >
{
   using type = Type;
   using ip   = IP;
};

然后Matrix_Base现在只接受一个模板参数,即矩阵类本身,并从 traits 类中获取类型

template<class Matrix_Type>
class Matrix_Base
{
   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      // We now get Type and IP from Matrix_Traits
      return Implementation::Plus_Equal<typename Matrix_Traits<Matrix_Type>::type
                                      , typename Matrix_Traits<Matrix_Type>::ip
                                      >::apply(static_cast<Matrix_Type&>(*this), value_);
   }
};