是否可以在C++中创建泛型函数指针?

Can I create generic function pointers in C++?

本文关键字:泛型 函数 指针 创建 C++ 是否      更新时间:2023-10-16

我有以下方法:

Matrix& relu(float threshold, Matrix& M)
Matrix& softmax(Matrix& M)

我想要一个可以接收这两种方法中的任何一种的函数指针,这在C++中甚至可能吗? 如果没有,是否有任何优雅的解决方法?

谢谢

有一个运行时解决方案,它会慢一点,你必须检查它是哪种类型

#include <iostream>
#include <functional>
#include <variant>
struct Matrix {
float test{0};
};
Matrix& relu(float threshold, Matrix& M) {
M.test = threshold;
return M;
}
Matrix& softmax(Matrix& M) {
M.test = 42;
return M;
}
using matrix_generic_function_pointer_t =
std::variant<std::function<Matrix&(Matrix&)>, std::function<Matrix& (float, Matrix&)>>;
void use_function(matrix_generic_function_pointer_t function_pointer) {
Matrix m{}; float something = 1337;
if (function_pointer.index() == 0) {
std::cout << "Softmax(): " << std::get<std::function<Matrix&(Matrix&)>>(function_pointer)(m).test;
} else {
std::cout << "Relu(): " << std::get<std::function<Matrix& (float, Matrix&)>>(function_pointer)(something,m).test;
}
}
int main()
{
matrix_generic_function_pointer_t func1{&relu};
matrix_generic_function_pointer_t func2{&softmax};
use_function(func1);
use_function(func2);
std::cout << "n";
}

还有一个编译时版本,使用起来更快、更安全,因为它不会在运行时失败,而是生成编译器错误,但对于您的代码来说有点不灵活:

#include <iostream>
#include <functional>
struct Matrix {
float test{0};
};
Matrix& relu(float threshold, Matrix& M) {
M.test = threshold;
return M;
}
Matrix& softmax(Matrix& M) {
M.test = 42;
return M;
}
template <typename F>
void use_function(F&& function) {
Matrix m{};
std::cout << function(m).test << "n";
}
int main()
{
use_function(std::bind(&relu, 1337, std::placeholders::_1));
use_function(&softmax);
std::cout << "n";
}

从技术上讲,您可以使用任何函数指针指向这两个函数。只需显式转换类型。当然,如果类型不匹配,您将无法通过指针调用。

另一种选择是将函数指针存储在类型擦除包装器(如std::anystd::variant(中,这些包装器提供了可能有用的访问机制。

因此,您确实有办法指向不同的功能,但另一个问题是这样做是否有用。

您可以使用泛型函数包装器std::function

Matrix& relu(float threshold, Matrix& M);
Matrix& softmax(Matrix& M);
// Accepts a reference to Matrix, returns a reference to Matrix.
// It probably modifies the input matrix in-place.
using F = std::function<Matrix&(Matrix& M)>; 
int main() {
F f1{softmax};
float threshold = 0.1; // threshold gets copied into the lambda capture below.
F f2{[threshold](Matrix& M) -> Matrix& { return relu(threshold, M); }};
Matrix m;
Matrix& m2 = f1(m);
Matrix& m3 = f2(m);
}

您的激活函数relusoftmax可能应该接受对 const 向量的引用,并按值返回新向量。