比较两个向量<bool>与SSE的关系

Comparing two vector<bool> with SSE

本文关键字:gt bool SSE 关系 lt 两个 向量 比较      更新时间:2023-10-16

我有两个vector<bool> A和B。

我想比较它们,并计算相等的元素数量:

例如:

A = {0,1,0,1}
B = {0,0,1,1}

结果将等于2。

我可以使用_mm_cmpeq_epi8,但它只比较16个元素(即,我应该将0和1转换为char,然后进行比较)。是否可以每次将128个元素与SSE(或SIMD指令)进行比较?

如果您可以假设vector<bool>使用连续字节大小的元素进行存储,或者您可以考虑使用类似vector<uint8_t>的元素,那么这个例子应该为您提供一个很好的起点:

static size_t count_equal(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());         // vectors must be same size
    const size_t n = vec1.size();
    const size_t max_block_size = 255 * 16;     // max block size before possible overflow
    __m128i vcount = _mm_setzero_si128();
    size_t i, count = 0;
    for (i = 0; i + 16 <= n; )                  // for each block
    {
        size_t m = std::min(n, i + max_block_size);
        for ( ; i + 16 <= m; i += 16)           // for each vector in block
        {
            __m128i v1 = _mm_loadu_si128((__m128i *)&vec1[i]);
            __m128i v2 = _mm_loadu_si128((__m128i *)&vec2[i]);
            __m128i vcmp = _mm_cmpeq_epi8(v1, v2);
            vcount = _mm_sub_epi8(vcount, vcmp);
        }
        vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());
        count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);
        vcount = _mm_setzero_si128();           // update count from current block
    }
    vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());
    count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);
    for ( ; i < n; ++i)                         // deal with any remaining partial vector
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}

请注意,这是使用vector<uint8_t>。如果您真的必须使用vector<bool>,并且可以保证元素始终是连续的和字节大小的,那么您只需要以某种方式将vector<bool>强制为const uint8_t *或类似的内容。

测试线束:

#include <cassert>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
#include <emmintrin.h>    // SSE2
using std::vector;
static size_t count_equal_ref(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());
    const size_t n = vec1.size();
    size_t i, count = 0;
    for (i = 0 ; i < n; ++i)
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}
static size_t count_equal(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());         // vectors must be same size
    const size_t n = vec1.size();
    const size_t max_block_size = 255 * 16;     // max block size before possible overflow
    __m128i vcount = _mm_setzero_si128();
    size_t i, count = 0;
    for (i = 0; i + 16 <= n; )                  // for each block
    {
        size_t m = std::min(n, i + max_block_size);
        for ( ; i + 16 <= m; i += 16)           // for each vector in block
        {
            __m128i v1 = _mm_loadu_si128((__m128i *)&vec1[i]);
            __m128i v2 = _mm_loadu_si128((__m128i *)&vec2[i]);
            __m128i vcmp = _mm_cmpeq_epi8(v1, v2);
            vcount = _mm_sub_epi8(vcount, vcmp);
        }
        vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());
        count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);
        vcount = _mm_setzero_si128();           // update count from current block
    }
    vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());
    count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);
    for ( ; i < n; ++i)                         // deal with any remaining partial vector
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}
int main(int argc, char * argv[])
{
    size_t n = 100;
    if (argc > 1)
    {
        n = atoi(argv[1]);
    }
    vector<uint8_t> vec1(n);
    vector<uint8_t> vec2(n);
    srand((unsigned int)time(NULL));
    for (size_t i = 0; i < n; ++i)
    {
        vec1[i] = rand() & 1;
        vec2[i] = rand() & 1;
    }
    size_t n_ref = count_equal_ref(vec1, vec2);
    size_t n_test = count_equal(vec1, vec2);
    if (n_ref == n_test)
    {
        std::cout << "PASS" << std::endl;
    }
    else
    {
        std::cout << "FAIL: n_ref = " << n_ref << ", n_test = " << n_test << std::endl;
    }
    return 0;
}

编译并运行:

$ g++ -Wall -msse3 -O3 test.cpp && ./a.out
PASS

std::vector<bool>是类型boolstd::vector的专门化。虽然没有由C++标准指定,但在大多数实现中,std::vector<bool>是空间有效的,使得它的每个元素是单个比特而不是bool

std::vector<bool>的行为与其主要模板对应物相似,不同之处在于:

  1. CCD_ 15不一定连续地存储其元素
  2. 为了暴露其元素(即,各个比特),std::vector<bool>使用代理类(即,std::vector<bool>::reference)。class std::vector<bool>::reference的对象由std::vector<bool>下标运算符(即operator[])按值返回

因此,我不认为使用类似_mm_cmpeq_epi8的函数是可移植的,因为std::vector<bool>的存储是由实现定义的(即,不保证连续)。

另一种可移植的方法是使用常规STL设施,如下面的示例:

std::vector<bool> A = {0,1,0,1};
std::vector<bool> B = {0,0,1,1};
std::vector<bool> C(A.size());
std::transform(A.begin(), A.end(), B.begin(), C.begin(), [](bool const &a, bool const &b) { return a == b;});
std::cout << std::count(C.begin(), C.end(), true) << std::endl;

实时演示