在AVX2中再现_mm256_sllv_epi16和_mm256-slv_epi8

Reproduce _mm256_sllv_epi16 and _mm256_sllv_epi8 in AVX2

本文关键字:epi16 mm256-slv epi8 sllv mm256 AVX2      更新时间:2023-10-16

我很惊讶地发现_mm256_sllv_epi16/8(__m256i v1, __m256i v2)_mm256_srlv_epi16/8(__m256i v1, __m256i v2)没有出现在《英特尔本质指南》中,而且我找不到任何解决方案来仅用AVX2重新创建AVX512本质。

此函数将所有16/8位的压缩int左移v2中相应数据元素的计数值。

epi16示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
v1 = _mm256_sllv_epi16(v1, v2);

则v1等于->(1111111111111111,11111111111111110,1111111111111100,1111111111111000,。。。。。。。。。。。。。。。。,1000000000000000);

_mm256_sllv_epi8的情况下,使用pshufb指令作为一个小型查找表,用乘法替换移位并不太困难。还可以通过乘法和许多其他指令模拟_mm256_srlv_epi8的右移,请参阅下面的代码。我希望_mm256_sllv_epi8至少比Nyan的解决方案更有效。


或多或少可以使用相同的想法来模拟_mm256_sllv_epi16,但在这种情况下,选择正确的乘法器就不那么简单了(另请参阅下面的代码)。

下面的解决方案_mm256_sllv_epi16_emu不一定比Nyan的解决方案更快,也不一定更好。性能取决于周围的代码和使用的CPU。尽管如此,这里的解决方案可能会引起人们的兴趣,至少在旧的计算机系统上是这样。例如,vpsllvd指令在Nyan的解决方案中使用了两次。此指令在英特尔Skylake系统或更新版本上速度很快。在Intel Broadwell或Haswell上,此指令速度较慢,因为它可以解码到3个微操作。这里的解决方案避免了这种缓慢的指令。

如果已知移位计数小于或等于15,则可以跳过具有mask_lt_15的两行代码。

缺失的内在_mm256_srlv_epi16留给读者练习。


/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell shift_v_epi8.c     */
#include <immintrin.h>
#include <stdio.h>
int print_epi8(__m256i  a);
int print_epi16(__m256i  a);
__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
__m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
__m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
__m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i << n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
__m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
__m256i x_lo           = _mm256_mullo_epi16(a, multiplier);               /* Unfortunately _mm256_mullo_epi8 doesn't exist. Split the 16 bit elements in a high and low part. */
__m256i multiplier_hi  = _mm256_srli_epi16(multiplier, 8);                /* The multiplier of the high bits.                                                                 */
__m256i a_hi           = _mm256_and_si256(a, mask_hi);                    /* Mask off the low bits.                                                                           */
__m256i x_hi           = _mm256_mullo_epi16(a_hi, multiplier_hi);
__m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
return x;
}

__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
__m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
__m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128, 0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128);
__m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i >> n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
__m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
__m256i a_lo           = _mm256_andnot_si256(mask_hi, a);                 /* Mask off the high bits.                                                                          */
__m256i multiplier_lo  = _mm256_andnot_si256(mask_hi, multiplier);        /* The multiplier of the low bits.                                                                  */
__m256i x_lo           = _mm256_mullo_epi16(a_lo, multiplier_lo);         /* Shift left a_lo by multiplying.                                                                  */
x_lo           = _mm256_srli_epi16(x_lo, 7);                      /* Shift right by 7 to get the low bits at the right position.                                      */
__m256i multiplier_hi  = _mm256_and_si256(mask_hi, multiplier);           /* The multiplier of the high bits.                                                                 */
__m256i x_hi           = _mm256_mulhi_epu16(a, multiplier_hi);            /* Variable shift left a_hi by multiplying. Use a instead of a_hi because the a_lo bits don't interfere */
x_hi           = _mm256_slli_epi16(x_hi, 1);                      /* Shift left by 1 to get the high bits at the right position.                                      */
__m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
return x;
}

__m256i _mm256_sllv_epi16_emu(__m256i a, __m256i count) {
__m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
__m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0, 14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);
__m256i mask_lt_15     = _mm256_cmpgt_epi16(_mm256_set1_epi16(16), count);
a              = _mm256_and_si256(mask_lt_15, a);                    /* Set a to zero if count > 15.                                                                      */
count          = _mm256_shuffle_epi8(count, byte_shuf_mask);         /* Duplicate bytes from the even postions to bytes at the even and odd positions.                    */
count          = _mm256_sub_epi8(count,_mm256_set1_epi16(0x0800));   /* Subtract 8 at the even byte positions. Note that the vpshufb instruction selects a zero byte if the shuffle control mask is negative.     */
__m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count);         /* Select the right multiplication factor in the lookup table. Within the 16 bit elements, only the upper byte or the lower byte is nonzero. */
__m256i x              = _mm256_mullo_epi16(a, multiplier);                  
return x;
}

int main(){
printf("Emulating _mm256_sllv_epi8:n");
__m256i a     = _mm256_set_epi8(32,31,30,29, 28,27,26,25, 24,23,22,21, 20,19,18,17, 16,15,14,13, 12,11,10,9, 8,7,6,5, 4,3,2,1);
__m256i count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
__m256i x     = _mm256_sllv_epi8(a, count);
printf("a     = n"); print_epi8(a    );
printf("count = n"); print_epi8(count);
printf("x     = n"); print_epi8(x    );
printf("nn"); 

printf("Emulating _mm256_srlv_epi8:n");
a     = _mm256_set_epi8(223,224,225,226, 227,228,229,230, 231,232,233,234, 235,236,237,238, 239,240,241,242, 243,244,245,246, 247,248,249,250, 251,252,253,254);
count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
x     = _mm256_srlv_epi8(a, count);
printf("a     = n"); print_epi8(a    );
printf("count = n"); print_epi8(count);
printf("x     = n"); print_epi8(x    );
printf("nn"); 

printf("Emulating _mm256_sllv_epi16:n");
a     = _mm256_set_epi16(1601,1501,1401,1301, 1200,1100,1000,900, 800,700,600,500, 400,300,200,100);
count = _mm256_set_epi16(17,16,15,13,  11,10,9,8, 7,6,5,4, 3,2,1,0);
x     = _mm256_sllv_epi16_emu(a, count);
printf("a     = n"); print_epi16(a    );
printf("count = n"); print_epi16(count);
printf("x     = n"); print_epi16(x    );
printf("nn"); 
return 0;
}

int print_epi8(__m256i  a){
char v[32];
int i;
_mm256_storeu_si256((__m256i *)v,a);
for (i = 0; i<32; i++) printf("%4hhu",v[i]);
printf("n");
return 0;
}
int print_epi16(__m256i  a){
unsigned short int  v[16];
int i;
_mm256_storeu_si256((__m256i *)v,a);
for (i = 0; i<16; i++) printf("%6hu",v[i]);
printf("n");
return 0;
}

输出为:

Emulating _mm256_sllv_epi8:
a     = 
1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32
count = 
0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
1   4  12  32  80 192 192   0   0   0   0   0  13  28  60 128  16  64 192   0   0   0   0   0  25  52 108 224 208 192 192   0

Emulating _mm256_srlv_epi8:
a     = 
254 253 252 251 250 249 248 247 246 245 244 243 242 241 240 239 238 237 236 235 234 233 232 231 230 229 228 227 226 225 224 223
count = 
0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
254 126  63  31  15   7   3   1   0   0   0   0 242 120  60  29  14   7   3   1   0   0   0   0 230 114  57  28  14   7   3   1

Emulating _mm256_sllv_epi16:
a     = 
100   200   300   400   500   600   700   800   900  1000  1100  1200  1301  1401  1501  1601
count = 
0     1     2     3     4     5     6     7     8     9    10    11    13    15    16    17
x     = 
100   400  1200  3200  8000 19200 44800 36864 33792 53248 12288 32768 40960 32768     0     0

确实缺少一些AVX2指令。然而,请注意,通过模拟"缺失"的AVX2指令来填补这些空白并不总是一个好主意。有时是以避免这些模拟指令的方式重新设计代码更有效。例如,通过使用更宽的矢量元素(_epi32而不是_epi16)。

奇怪的是,他们错过了这一点,尽管许多AVX整数指令似乎只适用于32/64位宽。AVX512BW中至少添加了16位(尽管我仍然不明白英特尔为什么拒绝添加8位移位)。

我们可以通过使用32位变量移位和一些掩蔽和混合,仅使用AVX2来模拟16位变量移位。

我们需要在包含每个16位元素的32位元素的底部进行右移计数,这可以通过AND(对于低元素)和立即移位(对于高半部分)来完成。(与标量移位不同,x86矢量移位使其计数饱和,而不是包装/屏蔽)。

在进行高半移位之前,我们还需要屏蔽数据的低16位,这样我们就不会将垃圾移位到包含32位元素的高16位半中。

__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi32(0xffff0000);
__m256i low_half = _mm256_sllv_epi32(
a,
_mm256_andnot_si256(mask, count)
);
__m256i high_half = _mm256_sllv_epi32(
_mm256_and_si256(mask, a),
_mm256_srli_epi32(count, 16)
);
return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi32(0xffff0000); // alternating low/high words of a dword
// shift low word of each dword: low_half = (a << (count & 0xffff)) [for each 32b element]
// note that, because `a` isn't being masked here, we may get some "junk" bits, but these will get eliminated by the blend below
__m256i low_half = _mm256_sllv_epi32(
a,
_mm256_andnot_si256(mask, count)
);
// shift high word of each dword: high_half = ((a & 0xffff0000) << (count >> 16)) [for each 32b element]
__m256i high_half = _mm256_sllv_epi32(
_mm256_and_si256(mask, a),     // make sure we shift in zeros
_mm256_srli_epi32(count, 16)   // need the high-16 count at the bottom of a 32-bit element
);
// combine low and high words
return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
__m256i _mm256_srlv_epi16(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi32(0x0000ffff);
__m256i low_half = _mm256_srlv_epi32(
_mm256_and_si256(mask, a),
_mm256_and_si256(mask, count)
);
__m256i high_half = _mm256_srlv_epi32(
a,
_mm256_srli_epi32(count, 16)
);
return _mm256_blend_epi16(low_half, high_half, 0xaa);
}

GCC 8.2将其编译为您所期望的:

_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
vmovdqa       ymm3, YMMWORD PTR .LC0[rip]
vpand   ymm2, ymm0, ymm3
vpand   ymm3, ymm1, ymm3
vpsrld  ymm1, ymm1, 16
vpsrlvd ymm2, ymm2, ymm3
vpsrlvd ymm0, ymm0, ymm1
vpblendw        ymm0, ymm2, ymm0, 170
ret
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
vmovdqa       ymm3, YMMWORD PTR .LC1[rip]
vpandn  ymm2, ymm3, ymm1
vpsrld  ymm1, ymm1, 16
vpsllvd ymm2, ymm0, ymm2
vpand   ymm0, ymm0, ymm3
vpsllvd ymm0, ymm0, ymm1
vpblendw        ymm0, ymm2, ymm0, 170
ret

意味着模拟结果为1x加载+2x AND/ANDN+2x可变移位+1x右移+1x混合。

Clang 6.0做了一些有趣的事情——它通过使用混合物来消除内存负载(以及相应的掩蔽):

_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
vpxor   xmm2, xmm2, xmm2
vpblendw        ymm3, ymm1, ymm2, 170
vpsllvd ymm3, ymm0, ymm3
vpsrld  ymm1, ymm1, 16
vpblendw        ymm0, ymm2, ymm0, 170
vpsllvd ymm0, ymm0, ymm1
vpblendw        ymm0, ymm3, ymm0, 170
ret
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
vpxor   xmm2, xmm2, xmm2
vpblendw        ymm3, ymm0, ymm2, 170
vpblendw        ymm2, ymm1, ymm2, 170
vpsrlvd ymm2, ymm3, ymm2
vpsrld  ymm1, ymm1, 16
vpsrlvd ymm0, ymm0, ymm1
vpblendw        ymm0, ymm2, ymm0, 170
ret

这导致:1次清除+3次混合+2次可变移位+1次右移。

我还没有对哪种方法更快进行任何基准测试,但我怀疑这可能取决于CPU,特别是CPU上PBLENDW的成本。

当然,如果您的用例有点受约束,则可以简化以上内容,例如,如果偏移量都是常量,则可以删除使其工作所需的屏蔽/偏移(假设编译器不会自动为您执行此操作)
对于左移,如果偏移量是恒定的,您可以使用_mm256_mullo_epi16,将偏移量转换为可以相乘的值,例如您给出的示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(1<<0,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7,1<<8,1<<9,1<<10,1<<11,1<<12,1<<13,1<<14,1<<15);
v1 = _mm256_mullo_epi16(v1, v2);

更新:Peter提到(见下面的评论),右移也可以用_mm256_mulhi_epi16实现(例如,执行v>>1v乘以1<<15并取高位字)。


对于8位变量移位,AVX512中也不存在这种情况(同样,我不知道英特尔为什么没有8位SIMD移位)
如果AVX512BW可用,则可以使用与上述类似的技巧,使用_mm256_sllv_epi16对于AVX2,我想不出比第二次应用16位模拟更好的方法了,因为您最终必须进行32位移位的4倍移位请参阅@wim的答案,了解AVX2中8位的良好解决方案。

这就是我想到的(基本上是AVX512上8位采用的16位版本):

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi16(0xff00);
__m256i low_half = _mm256_sllv_epi16(
a,
_mm256_andnot_si256(mask, count)
);
__m256i high_half = _mm256_sllv_epi16(
_mm256_and_si256(mask, a),
_mm256_srli_epi16(count, 8)
);
return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}
__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi16(0x00ff);
__m256i low_half = _mm256_srlv_epi16(
_mm256_and_si256(mask, a),
_mm256_and_si256(mask, count)
);
__m256i high_half = _mm256_srlv_epi16(
a,
_mm256_srli_epi16(count, 8)
);
return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}

(Peter Cordes在下面提到,在纯AVX512BW(+VL)实现中,_mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00))可以替换为_mm256_mask_blend_epi8(0xaaaaaaaa, low_half, high_half),这可能更快)

相关文章:
  • 没有找到相关文章