通过参考变差函数传递N-D数组

Pass N-D array by reference to variadic function

本文关键字:N-D 数组 函数 参考      更新时间:2023-10-16

我想让函数multi_dimensional通过引用接受一个多维数组。

这可以通过以下适用于three_dimensional的语法变体来实现吗?

#include <utility>
// this works, but number of dimensions must be known (not variadic)
template <size_t x, size_t y, size_t z>
void three_dimensional(int (&nd_array)[x][y][z]) {}
// error: parameter packs not expanded with ‘...’
template <size_t... dims>
void multi_dimensional(int (&nd_array)[dims]...) {}
int main() {
int array[2][3][2] = {
{ {0,1}, {2,3}, {4,5} },
{ {6,7}, {8,9}, {10,11} }
};
three_dimensional(array); // OK
// multi_dimensional(array); // error: no matching function
return 0;
}

主要问题是不能使数组维度本身的数量可变。因此,无论采用哪种方式,几乎肯定都需要某种递归方法来处理各个数组层。这种方法到底应该是什么样子,主要取决于一旦数组被提供给您,您计划对它做什么。

如果你真正想要的是一个可以被赋予任何多维数组的函数,那么就写一个可以赋予任何东西但只要任何东西是数组就存在的函数:

template <typename T>
std::enable_if_t<std::is_array_v<T>> multi_dimensional(T& a)
{
constexpr int dimensions = std::rank_v<T>;
// ...
}

然而,这本身很可能不会让你走得很远。要真正对给定的数组执行任何有意义的操作,您很可能需要一些递归遍历子数组。除非你真的只是想看看结构的最顶层。

另一种方法是使用递归模板剥离各个阵列级别,例如:

// we've reached the bottom
template <typename T, int N>
void multi_dimensional(T (&a)[N])
{
// ...
}
// this matches any array with more than one dimension
template <typename T, int N, int M>
void multi_dimensional(T (&a)[N][M])
{
// peel off one dimension, invoke function for each element on next layer
for (int i = 0; i < N; ++i)
multi_dimensional(a[i]);
}

然而,我建议至少考虑使用std::array<>而不是原始数组,因为原始数组的语法和特殊行为往往会很快将一切都变成混乱的局面。一般来说,实现自己的多维数组类型可能是值得的,比如NDArray<int, 2, 3, 2>,它在内部使用扁平表示,只将多维索引映射到线性索引。这种方法的一个优点(除了更干净的语法之外(是,您可以轻松地更改映射,例如,从行主布局切换到列主布局,例如,用于性能优化…

为了实现一个具有静态维度的通用nD数组,我将引入一个助手类来封装nD索引的线性索引的递归计算:

template <std::size_t... D>
struct row_major;
template <std::size_t D_n>
struct row_major<D_n>
{
static constexpr std::size_t SIZE = D_n;
std::size_t operator ()(std::size_t i_n) const
{
return i_n;
}
};
template <std::size_t D_1, std::size_t... D_n>
struct row_major<D_1, D_n...> : private row_major<D_n...>
{
static constexpr std::size_t SIZE = D_1 * row_major<D_n...>::SIZE;
template <typename... Tail>
std::size_t operator ()(std::size_t i_1, Tail&&... tail) const
{
return i_1 + D_1 * row_major<D_n...>::operator ()(std::forward<Tail>(tail)...);
}
};

然后:

template <typename T, std::size_t... D>
class NDArray
{
using memory_layout_t = row_major<D...>;
T data[memory_layout_t::SIZE];
public:
template <typename... Args>
T& operator ()(Args&&... args)
{
memory_layout_t memory_layout;
return data[memory_layout(std::forward<Args>(args)...)];
}
};

NDArray<int, 2, 3, 5> arr;
int main()
{
int x = arr(1, 2, 3);
}