使用SSE和AVX查找矩阵中最大的元素及其列和行索引

Find largest element in matrix and its column and row indexes using SSE and AVX

本文关键字:元素 索引 AVX SSE 查找 使用      更新时间:2023-10-16

我需要找到1d矩阵中最大的元素及其列和行索引。

我使用1d矩阵,所以只需要先找到max元素的索引,然后就可以很容易地获得行和列。

我的问题是我不能得到那个索引。

我有一个工作函数,它可以找到最大的元素并使用SSE,这里是:

float find_largest_element_in_matrix_SSE(float* m, unsigned const int dims)
{
    size_t i;
    int index = -1;
    __m128 max_el = _mm_loadu_ps(m);
    __m128 curr;
    for (i = 4; i < dims * dims; i += 4)
    {
        curr = _mm_loadu_ps(m + i);
        max_el = _mm_max_ps(max_el, curr);
    }
    __declspec(align(16))float max_v[4] = { 0 };
    _mm_store_ps(max_v, max_el);
    return max(max(max(max_v[0], max_v[1]), max_v[2]), max_v[3]);
}

我还有一个使用AVX:的非工作功能

float find_largest_element_in_matrix_AVX(float* m, unsigned const int dims)
{
    size_t i;
    int index = -1;
    __m256 max_el = _mm256_loadu_ps(m);
    __m256 curr;
    for (i = 8; i < dims * dims; i += 8)
    {
        curr = _mm256_loadu_ps(m + i);
        max_el = _mm256_max_ps(max_el, curr);
    }
    __declspec(align(32))float max_v[8] = { 0 };
    _mm256_store_ps(max_v, max_el);
    __m256 y = _mm256_permute2f128_ps(max_el, max_el, 1);
    __m256 m1 = _mm256_max_ps(max_el, y);m1[1] = max(max_el[1], max_el[3])
    __m256 m2 = _mm256_permute_ps(m1, 5); 
    __m256 m_res = _mm256_max_ps(m1, m2); 
    return m[0];
}

有人能帮我找到最大元素的索引并使我的AVX版本正常工作吗?

这里有一个工作的SSE(SSE 4)实现,它返回最大值和相应的索引,以及标量参考实现和测试工具:

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>
#include <smmintrin.h>  // SSE 4.1
float find_largest_element_in_matrix_ref(const float* m, int dims, int *maxIndex)
{
    float maxVal = m[0];
    int i;
    *maxIndex = 0;
    for (i = 1; i < dims * dims; ++i)
    {
        if (m[i] > maxVal)
        {
            maxVal = m[i];
            *maxIndex = i;
        }
    }
    return maxVal;
}
float find_largest_element_in_matrix_SSE(const float* m, int dims, int *maxIndex)
{
    float maxVal = m[0];
    float aMaxVal[4];
    int32_t aMaxIndex[4];
    int i;
    *maxIndex = 0;
    const __m128i vIndexInc = _mm_set1_epi32(4);
    __m128i vMaxIndex = _mm_setr_epi32(0, 1, 2, 3);
    __m128i vIndex = vMaxIndex;
    __m128 vMaxVal = _mm_loadu_ps(m);
    for (i = 4; i < dims * dims; i += 4)
    {
        __m128 v = _mm_loadu_ps(&m[i]);
        __m128 vcmp = _mm_cmpgt_ps(v, vMaxVal);
        vIndex = _mm_add_epi32(vIndex, vIndexInc);
        vMaxVal = _mm_max_ps(vMaxVal, v);
        vMaxIndex = _mm_blendv_epi8(vMaxIndex, vIndex, _mm_castps_si128(vcmp));
    }
    _mm_storeu_ps(aMaxVal, vMaxVal);
    _mm_storeu_si128((__m128i *)aMaxIndex, vMaxIndex);
    maxVal = aMaxVal[0];
    *maxIndex = aMaxIndex[0];
    for (i = 1; i < 4; ++i)
    {
        if (aMaxVal[i] > maxVal)
        {
            maxVal = aMaxVal[i];
            *maxIndex = aMaxIndex[i];
        }
    }
    return maxVal;
}
int main()
{
    const int dims = 1024;
    float m[dims * dims];
    float maxVal_ref, maxVal_SSE;
    int maxIndex_ref = -1, maxIndex_SSE = -1;
    int i;
    srand(time(NULL));
    for (i = 0; i < dims * dims; ++i)
    {
        m[i] = (float)rand() / RAND_MAX;
    }
    maxVal_ref = find_largest_element_in_matrix_ref(m, dims, &maxIndex_ref);
    maxVal_SSE = find_largest_element_in_matrix_SSE(m, dims, &maxIndex_SSE);
    if (maxVal_ref == maxVal_SSE && maxIndex_ref == maxIndex_SSE)
    {
        printf("PASS: maxVal = %f, maxIndex = %dn",
                      maxVal_ref, maxIndex_ref);
    }
    else
    {
        printf("FAIL: maxVal_ref = %f, maxVal_SSE = %f, maxIndex_ref = %d, maxIndex_SSE = %dn",
                      maxVal_ref, maxVal_SSE, maxIndex_ref, maxIndex_SSE);
    }
    return 0;
}

编译并运行:

$ gcc -Wall -msse4 Yakovenko.c && ./a.out 
PASS: maxVal = 0.999999, maxIndex = 120409

显然,如果需要,您可以获得行和列索引:

int rowIndex = maxIndex / dims;
int colIndex = maxIndex % dims;

从这里开始,编写AVX2实现应该相当简单。

一种方法是在第一次遍历中计算最大值,在第二次遍历中通过线性搜索找到索引

#define anybit __builtin_ctz   //or lookup table with 16 entries...
float find_largest_element_in_matrix_SSE(const float* m, int dims, int *maxIndex) {
    //first pass: calculate maximum as usual
    __m128 vMaxVal = _mm_loadu_ps(m);
    for (int i = 4; i < dims * dims; i += 4)
        vMaxVal = _mm_max_ps(vMaxVal, _mm_loadu_ps(&m[i]));
    //perform in-register reduction
    vMaxVal = _mm_max_ps(vMaxVal, _mm_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(2, 3, 0, 1)));
    vMaxVal = _mm_max_ps(vMaxVal, _mm_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(1, 0, 3, 2)));
    //second pass: search for maximal value
    for (int i = 0; i < dims * dims; i += 4) {
        __m128 vIsMax = _mm_cmpeq_ps(vMaxVal, _mm_loadu_ps(&m[i]));
        if (int mask = _mm_movemask_ps(vIsMax)) {
            *maxIndex = i + anybit(mask);
            return _mm_cvtss_f32(vMaxVal);
        }
    }
}

请注意,除非您的输入数据非常小,否则第二个循环中的分支应该几乎完全被预测。

该解决方案存在几个问题,特别是:

  1. 它可能在存在奇怪的浮点值时工作不正确,例如使用NaN。

  2. 如果您的矩阵不适合CPU缓存,那么代码将从主存中读取矩阵两次,因此它将比单程方法慢两倍。对于大型矩阵,可以通过逐块处理来解决这一问题。

  3. 在第一个循环中,每次迭代都取决于前一次迭代(vMaxVal既被修改又被读取),因此它将因_mm_max_ps的延迟而减慢。也许最好将第一个循环展开一位(2x或4x),同时为vMaxVal提供4个独立的寄存器(实际上,第二个循环也将受益于展开)

移植到AVX应该是非常直接的,除了寄存器内减少:

vMaxVal = _mm256_max_ps(vMaxVal, _mm256_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(2, 3, 0, 1)));
vMaxVal = _mm256_max_ps(vMaxVal, _mm256_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(1, 0, 3, 2)));
vMaxVal = _mm256_max_ps(vMaxVal, _mm256_permute2f128_ps(vMaxVal, vMaxVal, 1));

另一种方法:

void find_largest_element_in_matrix_SSE(float * matrix, size_t n, int * row, int * column, float * v){
    __m128 indecies = _mm_setr_ps(0, 1, 2, 3);
    __m128 update = _mm_setr_ps(4, 4, 4, 4);
    __m128 max_indecies = _mm_setr_ps(0, 1, 2, 3);
    __m128 max = _mm_load_ps(matrix);
    for (int i = 4; i < n * n; i+=4){
        indecies = _mm_add_ps(indecies, update);
        __m128 pm2 = _mm_load_ps(&matrix[i]);
        __m128 mask = _mm_cmpge_ps(max, pm2);
        max = _mm_max_ps(max, pm2);
        max_indecies = _mm_or_ps(_mm_and_ps(max_indecies, mask), _mm_andnot_ps(mask, indecies));
    }
    __declspec (align(16)) int max_ind[4];
    __m128i maxi = _mm_cvtps_epi32(max_indecies);
    _mm_store_si128((__m128i *) max_ind, maxi);
    int c = max_ind[0];
    for (int i = 1; i < 4; i++)
        if (matrix[max_ind[i]] >= matrix[c] && max_ind[i] < c){
            c = max_ind[i];
        }
    *v = matrix[c];
    *row = c / n;
    *column = c % n;
}
void find_largest_element_in_matrix_AVX(float * matrix, size_t n, int * row,  int * column, float * v){
    __m256 indecies = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
    __m256 update = _mm256_setr_ps(8, 8, 8, 8, 8, 8, 8, 8);
    __m256 max_indecies = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
    __m256 max = _mm256_load_ps(matrix);
    for (int i = 8; i < n * n; i += 8){
        indecies = _mm256_add_ps(indecies, update);
        __m256 pm2 = _mm256_load_ps(&matrix[i]);
        __m256 mask = _mm256_cmp_ps(max, pm2, _CMP_GE_OQ);
        max = _mm256_max_ps(max, pm2);
        max_indecies = _mm256_or_ps(_mm256_and_ps(max_indecies, mask), _mm256_andnot_ps(mask, indecies));
    }
    __declspec (align(32)) int max_ind[8];
    __m256i maxi = _mm256_cvtps_epi32(max_indecies);
    _mm256_store_si256((__m256i *) max_ind, maxi);
    int c = max_ind[0];
    for (int i = 1; i < 8; i++)
        if (matrix[max_ind[i]] >= matrix[c] && max_ind[i] < c){
            c = max_ind[i];
        }
    *v = matrix[c];
    *row = c / n;
    *column = c % n;
}