AVX2列填充计数算法分别针对每个位列

AVX2 column population count algorithm over each bit-column separately

本文关键字:填充 算法 AVX2      更新时间:2023-10-16

对于我正在进行的项目,我需要计算撕裂的PDF图像数据中每列的设置位数。

我试图获得整个PDF作业(所有页面)中每列的总设置位计数。

数据一旦被剥离,就存储在MemoryMappedFile中,没有备份文件(在内存中)。

PDF页面尺寸为13952像素x 15125像素。可以通过将以像素为单位的PDF的长度(高度)乘以以字节为单位的宽度来计算得到的撕裂数据的总大小。被撕裂的数据是1 bit == 1 pixel。因此,以字节为单位的撕裂页的大小是(13952 / 8) * 15125

注意,宽度总是64 bits的倍数。

在被撕扯后,我必须计算PDF(可能是数万页)每页中每列的设置位。

我首先用一个基本的解决方案来解决这个问题,只需循环遍历每个字节,计算设置位的数量,并将结果放入vector中。从那以后,我将算法简化为如下所示。我的执行时间从大约350毫秒增加到大约120毫秒。

static void count_dots( )
{
using namespace diag;
using namespace std::chrono;
std::vector<std::size_t> dot_counts( 13952, 0 );
uint64_t* ptr_dot_counts{ dot_counts.data( ) };
std::vector<uint64_t> ripped_pdf_data( 3297250, 0xFFFFFFFFFFFFFFFFUL );
const uint64_t* ptr_data{ ripped_pdf_data.data( ) };
std::size_t line_count{ 0 };
std::size_t counter{ ripped_pdf_data.size( ) };
stopwatch sw;
sw.start( );
while( counter > 0 )
{
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000001UL ) >> 0;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0100000000000000UL ) >> 56;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0001000000000000UL ) >> 48;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000010000000000UL ) >> 40;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000100000000UL ) >> 32;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000001000000UL ) >> 24;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000010000UL ) >> 16;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000100UL ) >> 8;
*ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000001UL ) >> 0;
++ptr_data;
--counter;
if( ++line_count >= 218 )
{
ptr_dot_counts = dot_counts.data( );
line_count = 0;
}
}   
sw.stop( );
std::cout << sw.elapsed<milliseconds>( ) << "msn";
}

不幸的是,这仍然会增加很多额外的处理时间,这是不可接受的。

上面的代码很难看,不会赢得任何选美比赛,但它有助于减少执行时间。自从我写的原始版本以来,我已经做了以下工作:

  • 使用pointers而不是indexers
  • uint64而不是uint8的块处理数据
  • 手动展开for循环以遍历uint64的每个byte中的每个bit
  • 使用最终的bit shift而不是__popcnt64对屏蔽后的集合bit进行计数

对于这个测试,我生成伪造的撕裂数据,其中每个bit都设置为1。测试完成后,dot_countsvector应包含每个element15125

我希望这里的一些人能帮助我将算法的平均执行时间控制在100毫秒以下。我从来都不在乎这里的便携性。

  • 目标机器的CPU:Xeon E5-2680 v4 - Intel
  • 编译器:MSVC++ 14.23
  • 操作系统:Windows 10
  • C++版本:C++17
  • 编译器标志:/O2/arch:AVX2

大约8年前,有人问过一个非常相似的问题:如何在Sandy Bridge上的一系列int中快速将比特计数到单独的垃圾箱中?

(编者按:也许你错过了在许多64位位掩码上分别计算每个位的位置,使用AVX,而不是AVX2,它有一些最近更快的答案,至少对于向下一列而不是沿着连续内存中的一行。也许你可以向下一列走1或2条缓存线,这样你就可以在SIMD寄存器中保持计数器的热度。)

当我把迄今为止我得到的答案与公认的答案进行比较时,我已经相当接近了。我已经在处理uint64而不是uint8的块。我只是想知道我是否还能做更多的事情,无论是使用内部函数、汇编,还是像更改我使用的数据结构这样简单的事情。

这可以用AVX2完成,如标签所示。

为了正确计算,我建议使用vector<uint16_t>进行计数。增加计数是最大的问题,我们越需要扩大,问题就越大。uint16_t足以计数一页,因此您可以一次计数一页并将计数器添加到一组更宽的计数器中以获得总计。这是一些开销,但比必须在主循环中加宽要少得多。

计数的big-endian排序非常令人讨厌,引入了更多的shuffle以使其正确。因此,我建议将其错误,然后重新排序计数(可能是在将它们相加为总数的过程中?)。"先右移7,然后右移6,然后右移5"的顺序可以免费保持,因为我们可以随心所欲地选择64位块的移位计数。因此,在下面的代码中,计数的实际顺序是:

  • 最低有效字节的比特7
  • 第二个字节的第7位
  • 最高有效字节的比特7
  • 最低有效字节的比特6

所以每组8个是相反的。(至少这是我打算做的,AVX2解包令人困惑)

代码(未测试):

while( counter > 0 )
{
__m256i data = _mm256_set1_epi64x(*ptr_data);        
__m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
__m256i zero = _mm256_setzero_si256();
__m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);
ptr_dot_counts += 64;
++ptr_data;
--counter;
if( ++line_count >= 218 )
{
ptr_dot_counts = dot_counts.data( );
line_count = 0;
}
}

这可以进一步展开,一次处理多行。这很好,因为如前所述,向计数器求和是最大的问题,而按行展开将减少这一问题,而在寄存器中进行更简单的求和。

Somme内部使用:

  • _mm256_set1_epi64x将一个int64_t复制到向量的所有4个64位元素。同样适用于uint64_t
  • CCD_ 40将4个64比特的值转换成一个矢量
  • _mm256_srlv_epi64,右移逻辑,具有可变计数(对于每个元素可以是不同的计数)
  • _mm256_and_si256,只是按位AND
  • 此外,_mm256_add_epi16还适用于16位元素
  • _mm256_unpacklo_epi8_mm256_unpackhi_epi8,可能最好通过该页上的图表来解释

可以"垂直"求和,使用一个uint64_t保存64个单独求和的所有第0位,另一个uint64_t保存求和的全部第1位等。可以通过用逐位算术模拟全加法器(电路组件)来进行加法。然后,计数器不再只加0或1,而是同时加上更大的数字。

垂直和也可以向量化,但这会大大扩展将垂直和添加到列和的代码,所以我在这里没有这样做。它应该会有所帮助,但它只是大量的代码。

示例(未测试):

size_t y;
// sum 7 rows at once
for (y = 0; (y + 6) < 15125; y += 7) {
ptr_dot_counts = dot_counts.data( );
ptr_data = ripped_pdf_data.data( ) + y * 218;
for (size_t x = 0; x < 218; x++) {
uint64_t dataA = ptr_data[0];
uint64_t dataB = ptr_data[218];
uint64_t dataC = ptr_data[218 * 2];
uint64_t dataD = ptr_data[218 * 3];
uint64_t dataE = ptr_data[218 * 4];
uint64_t dataF = ptr_data[218 * 5];
uint64_t dataG = ptr_data[218 * 6];
// vertical sums, 7 bits to 3
uint64_t abc0 = (dataA ^ dataB) ^ dataC;
uint64_t abc1 = (dataA ^ dataB) & dataC | (dataA & dataB);
uint64_t def0 = (dataD ^ dataE) ^ dataF;
uint64_t def1 = (dataD ^ dataE) & dataF | (dataD & dataE);
uint64_t bit0 = (abc0 ^ def0) ^ dataG;
uint64_t c1   = (abc0 ^ def0) & dataG | (abc0 & def0);
uint64_t bit1 = (abc1 ^ def1) ^ c1;
uint64_t bit2 = (abc1 ^ def1) & c1 | (abc1 & def1);
// add vertical sums to column counts
__m256i bit0v = _mm256_set1_epi64x(bit0);
__m256i data01 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data02 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(0, 2, 1, 3));
data01 = _mm256_and_si256(data01, _mm256_set1_epi8(1));
data02 = _mm256_and_si256(data02, _mm256_set1_epi8(1));
__m256i bit1v = _mm256_set1_epi64x(bit1);
__m256i data11 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data12 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(0, 2, 1, 3));
data11 = _mm256_and_si256(data11, _mm256_set1_epi8(1));
data12 = _mm256_and_si256(data12, _mm256_set1_epi8(1));
data11 = _mm256_add_epi8(data11, data11);
data12 = _mm256_add_epi8(data12, data12);
__m256i bit2v = _mm256_set1_epi64x(bit2);
__m256i data21 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data22 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(0, 2, 1, 3));
data21 = _mm256_and_si256(data21, _mm256_set1_epi8(1));
data22 = _mm256_and_si256(data22, _mm256_set1_epi8(1));
data21 = _mm256_slli_epi16(data21, 2);
data22 = _mm256_slli_epi16(data22, 2);
__m256i data1 = _mm256_add_epi8(_mm256_add_epi8(data01, data11), data21);
__m256i data2 = _mm256_add_epi8(_mm256_add_epi8(data02, data12), data22);
__m256i zero = _mm256_setzero_si256();
__m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

ptr_dot_counts += 64;
++ptr_data;
}
}
// leftover rows
for (; y < 15125; y++) {
ptr_dot_counts = dot_counts.data( );
ptr_data = ripped_pdf_data.data( ) + y * 218;
for (size_t x = 0; x < 218; x++) {
__m256i data = _mm256_set1_epi64x(*ptr_data);
__m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
__m256i zero = _mm256_setzero_si256();
__m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);
c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

ptr_dot_counts += 64;
++ptr_data;
}
}

到目前为止,第二好的方法是一种更简单的方法,更像第一个版本,除了同时运行yloopLen行以利用快速的8位求和:

size_t yloopLen = 32;
size_t yblock = yloopLen * 1;
size_t yy;
for (yy = 0; yy < 15125; yy += yblock) {
for (size_t x = 0; x < 218; x++) {
ptr_data = ripped_pdf_data.data() + x;
ptr_dot_counts = dot_counts.data() + x * 64;
__m256i zero = _mm256_setzero_si256();
__m256i c1 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
__m256i c2 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
__m256i c3 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
__m256i c4 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
size_t end = std::min(yy + yblock, size_t(15125));
size_t y;
for (y = yy; y < end; y += yloopLen) {
size_t len = std::min(size_t(yloopLen), end - y);
__m256i count1 = zero;
__m256i count2 = zero;
for (size_t t = 0; t < len; t++) {
__m256i data = _mm256_set1_epi64x(ptr_data[(y + t) * 218]);
__m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
__m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
count1 = _mm256_add_epi8(count1, data1);
count2 = _mm256_add_epi8(count2, data2);
}
c1 = _mm256_add_epi16(_mm256_unpacklo_epi8(count1, zero), c1);
c2 = _mm256_add_epi16(_mm256_unpackhi_epi8(count1, zero), c2);
c3 = _mm256_add_epi16(_mm256_unpacklo_epi8(count2, zero), c3);
c4 = _mm256_add_epi16(_mm256_unpackhi_epi8(count2, zero), c4);
}
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c1);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c2);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c3);
_mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c4);
}
}

以前也有一些测量问题,最终这并没有更好,但也没有比上面更花哨的"垂直和"版本差多少。