
Reproduce _mm256_sllv_epi16 and _mm256_sllv_epi8 in AVX2

我很惊讶地发现 _mm256_sllv_epi16/8(__m256i v1, __m256i v2) 和 _mm256_srlv_epi16/8(__m256i v1, __m256i v2) 没有出现在 Intel Intrinsics Guide(英特尔内建函数指南)中,而且我找不到任何仅用 AVX2 来重现这些 AVX512 内建函数(intrinsics)的解决方案。



/* Desired usage: v1 has every 16-bit element set to all ones, v2 holds the
 * per-element shift counts 0..15, and each element of v1 is shifted left by
 * the matching element of v2. */
__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
v1 = _mm256_sllv_epi16(v1, v2);




下面的解决方案 _mm256_sllv_epi16_emu 不一定比 Nyan 的解决方案更快或更好,性能取决于周围的代码和所使用的 CPU。尽管如此,这个方案仍可能有参考价值,至少在较旧的计算机系统上是这样。例如,Nyan 的解决方案中使用了两次 vpsllvd 指令。该指令在 Intel Skylake 及更新的处理器上速度很快;但在 Intel Broadwell 或 Haswell 上较慢,因为它会被解码为 3 个微操作(micro-ops)。这里的解决方案避免了这条较慢的指令。



/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell shift_v_epi8.c     */
#include <immintrin.h>
#include <stdio.h>
int print_epi8(__m256i  a);
int print_epi16(__m256i  a);
/*
 * Emulate a per-byte variable left shift with AVX2 only.
 *
 * a     : 32 byte elements to shift.
 * count : per-byte shift counts, compared as unsigned bytes; a count >= 8
 *         produces 0 for that byte (counts are not masked).
 * Returns the vector of (a_i << count_i) truncated to 8 bits.
 *
 * The shift is performed as a multiplication by 2^count. The factor is looked
 * up per byte with vpshufb; because there is no 8-bit multiply, the even and
 * odd bytes of each 16-bit element are multiplied separately and merged.
 */
__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    const __m256i mask_hi        = _mm256_set1_epi32((int)0xFF00FF00);
    /* Table entry i (0..7) is 2^i; entries 8..15 are 0 (same table in both 128-bit lanes). */
    const __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, (char)128,64,32,16, 8,4,2,1,
                                                   0,0,0,0, 0,0,0,0, (char)128,64,32,16, 8,4,2,1);
    /* Saturate the count to 8 so that every count >= 8 selects the zero entry. */
    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));
    /* Select the multiplication factor for every byte. */
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);
    /* The low byte of the 16-bit product depends only on the low bytes of both operands. */
    __m256i x_lo           = _mm256_mullo_epi16(a, multiplier);
    /* Move the factor of the high byte down to bits 0..7. */
    __m256i multiplier_hi  = _mm256_srli_epi16(multiplier, 8);
    /* Keep only the high byte of a so the low byte cannot carry into it. */
    __m256i a_hi           = _mm256_and_si256(a, mask_hi);
    __m256i x_hi           = _mm256_mullo_epi16(a_hi, multiplier_hi);
    /* Merge: low bytes from x_lo, high bytes from x_hi. */
    return _mm256_blendv_epi8(x_lo, x_hi, mask_hi);
}

/*
 * Emulate a per-byte variable logical right shift with AVX2 only.
 *
 * a     : 32 byte elements to shift.
 * count : per-byte shift counts, compared as unsigned bytes; a count >= 8
 *         produces 0 for that byte (counts are not masked).
 * Returns the vector of (a_i >> count_i), shifting in zeros.
 *
 * A right shift by n is computed as a multiplication by 2^(7-n) followed by
 * a fixed shift; the even and odd bytes of each 16-bit element are handled
 * separately and merged at the end.
 */
__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    const __m256i mask_hi        = _mm256_set1_epi32((int)0xFF00FF00);
    /* Table entry i (0..7) is 2^(7-i); entries 8..15 are 0 (same table in both 128-bit lanes). */
    const __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,(char)128,
                                                   0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,(char)128);
    /* Saturate the count to 8 so that every count >= 8 selects the zero entry. */
    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);

    /* Low bytes: (a_lo * 2^(7-n)) >> 7 == a_lo >> n. */
    __m256i a_lo           = _mm256_andnot_si256(mask_hi, a);
    __m256i multiplier_lo  = _mm256_andnot_si256(mask_hi, multiplier);
    __m256i x_lo           = _mm256_mullo_epi16(a_lo, multiplier_lo);
    x_lo                   = _mm256_srli_epi16(x_lo, 7);

    /* High bytes: the factor sits in bits 8..15, so the upper 16 bits of the
     * product are a >> (n + 1); shifting left by 1 puts the high byte in place.
     * a is used unmasked because its low byte cannot reach the high byte of
     * the result. */
    __m256i multiplier_hi  = _mm256_and_si256(mask_hi, multiplier);
    __m256i x_hi           = _mm256_mulhi_epu16(a, multiplier_hi);
    x_hi                   = _mm256_slli_epi16(x_hi, 1);

    /* Merge: low bytes from x_lo, high bytes from x_hi. */
    return _mm256_blendv_epi8(x_lo, x_hi, mask_hi);
}

/*
 * Emulate a per-16-bit-element variable left shift with AVX2 only, without
 * using the 32-bit variable shift instruction.
 *
 * a     : 16 elements of 16 bits to shift.
 * count : per-element shift counts, treated as unsigned 16-bit values; any
 *         count > 15 produces 0 for that element.
 * Returns the vector of (a_i << count_i) truncated to 16 bits.
 *
 * The shift is a multiplication by 2^count. The factor is assembled from a
 * byte lookup table: for counts 0..7 the low byte of the factor is nonzero,
 * for counts 8..15 the high byte is nonzero.
 */
__m256i _mm256_sllv_epi16_emu(__m256i a, __m256i count) {
    /* Table entry i (0..7) is 2^i; entries 8..15 are 0 (same table in both 128-bit lanes). */
    const __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, (char)128,64,32,16, 8,4,2,1,
                                                   0,0,0,0, 0,0,0,0, (char)128,64,32,16, 8,4,2,1);
    /* Copies the low byte of every 16-bit element into both of its bytes. */
    const __m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0,
                                                   14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);
    /* Unsigned test "count <= 15": min(count, 15) == count. A signed compare
     * would wrongly accept counts with the top bit set (e.g. 0x8003). */
    __m256i mask_le_15 = _mm256_cmpeq_epi16(_mm256_min_epu16(count, _mm256_set1_epi16(15)), count);
    /* Force the result to zero for out-of-range counts by zeroing the input. */
    a                  = _mm256_and_si256(mask_le_15, a);
    /* Duplicate the low count byte into both bytes of each element. */
    count              = _mm256_shuffle_epi8(count, byte_shuf_mask);
    /* Subtract 8 in the upper byte of each element only. vpshufb yields zero
     * for a negative index, so exactly one byte of the factor is nonzero. */
    count              = _mm256_sub_epi8(count, _mm256_set1_epi16(0x0800));
    /* Look up the multiplication factor 2^count. */
    __m256i multiplier = _mm256_shuffle_epi8(multiplier_lut, count);
    return _mm256_mullo_epi16(a, multiplier);
}

/*
 * Demo driver: runs the three emulated shifts on fixed inputs and prints the
 * input vector, the shift counts and the result of each one.
 */
int main(void) {
    printf("Emulating _mm256_sllv_epi8:\n");
    __m256i a     = _mm256_set_epi8(32,31,30,29, 28,27,26,25, 24,23,22,21, 20,19,18,17, 16,15,14,13, 12,11,10,9, 8,7,6,5, 4,3,2,1);
    __m256i count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    __m256i x     = _mm256_sllv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );

    printf("Emulating _mm256_srlv_epi8:\n");
    /* Values above 127 are converted to char; only the bit pattern matters here. */
    a     = _mm256_set_epi8(223,224,225,226, 227,228,229,230, 231,232,233,234, 235,236,237,238, 239,240,241,242, 243,244,245,246, 247,248,249,250, 251,252,253,254);
    count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    x     = _mm256_srlv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );

    printf("Emulating _mm256_sllv_epi16:\n");
    a     = _mm256_set_epi16(1601,1501,1401,1301, 1200,1100,1000,900, 800,700,600,500, 400,300,200,100);
    /* Counts 16 and 17 exercise the out-of-range case. */
    count = _mm256_set_epi16(17,16,15,13,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    x     = _mm256_sllv_epi16_emu(a, count);
    printf("a     = \n"); print_epi16(a    );
    printf("count = \n"); print_epi16(count);
    printf("x     = \n"); print_epi16(x    );
    return 0;
}

/*
 * Print the 32 bytes of a as unsigned decimal numbers, lowest element first,
 * each in a field of width 4, followed by a line break. Always returns 0.
 */
int print_epi8(__m256i  a){
    unsigned char v[32];   /* unsigned so the elements match the %hhu conversion */
    int i;
    _mm256_storeu_si256((__m256i *)v, a);
    for (i = 0; i < 32; i++) {
        printf("%4hhu", v[i]);
    }
    printf("\n");
    return 0;
}
/*
 * Print the 16 elements of a as unsigned 16-bit decimal numbers, lowest
 * element first, each in a field of width 6, followed by a line break.
 * Always returns 0.
 */
int print_epi16(__m256i  a){
    unsigned short int v[16];
    int i;
    _mm256_storeu_si256((__m256i *)v, a);
    for (i = 0; i < 16; i++) {
        printf("%6hu", v[i]);
    }
    printf("\n");
    return 0;
}


Emulating _mm256_sllv_epi8:
a     = 
1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32
count = 
0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
1   4  12  32  80 192 192   0   0   0   0   0  13  28  60 128  16  64 192   0   0   0   0   0  25  52 108 224 208 192 192   0

Emulating _mm256_srlv_epi8:
a     = 
254 253 252 251 250 249 248 247 246 245 244 243 242 241 240 239 238 237 236 235 234 233 232 231 230 229 228 227 226 225 224 223
count = 
0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
254 126  63  31  15   7   3   1   0   0   0   0 242 120  60  29  14   7   3   1   0   0   0   0 230 114  57  28  14   7   3   1

Emulating _mm256_sllv_epi16:
a     = 
100   200   300   400   500   600   700   800   900  1000  1100  1200  1301  1401  1501  1601
count = 
0     1     2     3     4     5     6     7     8     9    10    11    13    15    16    17
x     = 
100   400  1200  3200  8000 19200 44800 36864 33792 53248 12288 32768 40960 32768     0     0






/*
 * Emulate a per-16-bit-element variable left shift using two 32-bit variable
 * shifts (AVX2).
 *
 * a     : 16 elements of 16 bits to shift.
 * count : per-element shift counts.
 * Returns the vector of (a_i << count_i) truncated to 16 bits.
 */
__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
    /* Selects the high 16-bit word of every 32-bit element. */
    const __m256i mask = _mm256_set1_epi32((int)0xffff0000);
    /* Low words: shift the whole dword by the low count. Bits that spill into
     * the high word are discarded by the blend. */
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    /* High words: clear the low word first so that zeros are shifted in, and
     * move the high count down to the bottom of the dword. */
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi32(count, 16)
    );
    /* 0xaa takes the odd (high) words from high_half, the even ones from low_half. */
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
/*
 * Emulate a per-16-bit-element variable left shift using two 32-bit variable
 * shifts (AVX2).
 *
 * a     : 16 elements of 16 bits to shift.
 * count : per-element shift counts.
 * Returns the vector of (a_i << count_i) truncated to 16 bits.
 */
__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32((int)0xffff0000); // alternating low/high words of a dword
    // shift low word of each dword: low_half = (a << (count & 0xffff)) [for each 32b element]
    // note that, because `a` isn't being masked here, we may get some "junk" bits, but these will get eliminated by the blend below
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    // shift high word of each dword: high_half = ((a & 0xffff0000) << (count >> 16)) [for each 32b element]
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),     // make sure we shift in zeros
        _mm256_srli_epi32(count, 16)   // need the high-16 count at the bottom of a 32-bit element
    );
    // combine low and high words
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
/*
 * Emulate a per-16-bit-element variable logical right shift using two 32-bit
 * variable shifts (AVX2).
 *
 * a     : 16 elements of 16 bits to shift.
 * count : per-element shift counts.
 * Returns the vector of (a_i >> count_i), shifting in zeros.
 */
__m256i _mm256_srlv_epi16(__m256i a, __m256i count) {
    /* Selects the low 16-bit word of every 32-bit element. */
    const __m256i mask = _mm256_set1_epi32(0x0000ffff);
    /* Low words: clear the high word first so none of its bits are shifted
     * down into the result. */
    __m256i low_half = _mm256_srlv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    /* High words: shift the whole dword by the high count. Bits that land in
     * the low word are discarded by the blend. */
    __m256i high_half = _mm256_srlv_epi32(
        a,
        _mm256_srli_epi32(count, 16)
    );
    /* 0xaa takes the odd (high) words from high_half, the even ones from low_half. */
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}

GCC 8.2将其编译为您所期望的:

# Compiler listing for the emulated _mm256_srlv_epi16.
# NOTE(review): ymm0/ymm1 presumably hold a/count on entry and .LC0 presumably
# holds the 0x0000ffff dword mask from the C source; neither is shown here.
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
vmovdqa       ymm3, YMMWORD PTR .LC0[rip]
# ymm2 = ymm0 & constant, ymm3 = ymm1 & constant (operands of the first shift)
vpand   ymm2, ymm0, ymm3
vpand   ymm3, ymm1, ymm3
# ymm1 = ymm1 >> 16 in every dword (counts for the second shift)
vpsrld  ymm1, ymm1, 16
# two variable 32-bit right shifts: masked operand, then the unmasked ymm0
vpsrlvd ymm2, ymm2, ymm3
vpsrlvd ymm0, ymm0, ymm1
# merge 16-bit words of both results with immediate 170 (0xaa)
vpblendw        ymm0, ymm2, ymm0, 170
# Compiler listing for the emulated _mm256_sllv_epi16.
# NOTE(review): ymm0/ymm1 presumably hold a/count on entry and .LC1 presumably
# holds the 0xffff0000 dword mask from the C source; neither is shown here.
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
vmovdqa       ymm3, YMMWORD PTR .LC1[rip]
# ymm2 = ~constant & ymm1 (counts for the first shift)
vpandn  ymm2, ymm3, ymm1
# ymm1 = ymm1 >> 16 in every dword (counts for the second shift)
vpsrld  ymm1, ymm1, 16
# first variable 32-bit left shift, applied to the unmasked ymm0
vpsllvd ymm2, ymm0, ymm2
# second variable 32-bit left shift, applied to ymm0 & constant
vpand   ymm0, ymm0, ymm3
vpsllvd ymm0, ymm0, ymm1
# merge 16-bit words of both results with immediate 170 (0xaa)
vpblendw        ymm0, ymm2, ymm0, 170

也就是说,该模拟实现需要 1 次加载 + 2 次 AND/ANDN + 2 次可变移位 + 1 次右移 + 1 次混合(blend)。

Clang 6.0 做了一些有趣的事情——它通过使用混合(blend)指令消除了内存加载(以及相应的掩码操作):

# Compiler listing for the emulated _mm256_sllv_epi16; no constant is loaded
# from memory in this version.
# NOTE(review): ymm0/ymm1 presumably hold a/count on entry; not shown here.
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
# zero register used as the second blend source
vpxor   xmm2, xmm2, xmm2
# ymm3 = ymm1 with alternating 16-bit words replaced by zero (immediate 170 = 0xaa)
vpblendw        ymm3, ymm1, ymm2, 170
# first variable 32-bit left shift, applied to the unmasked ymm0
vpsllvd ymm3, ymm0, ymm3
# ymm1 = ymm1 >> 16 in every dword (counts for the second shift)
vpsrld  ymm1, ymm1, 16
# ymm0 = blend of zero and ymm0, i.e. alternating words of ymm0 cleared
vpblendw        ymm0, ymm2, ymm0, 170
vpsllvd ymm0, ymm0, ymm1
# merge 16-bit words of both shift results
vpblendw        ymm0, ymm3, ymm0, 170
# Compiler listing for the emulated _mm256_srlv_epi16; no constant is loaded
# from memory in this version.
# NOTE(review): ymm0/ymm1 presumably hold a/count on entry; not shown here.
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
# zero register used as the second blend source
vpxor   xmm2, xmm2, xmm2
# ymm3 / ymm2 = ymm0 / ymm1 with alternating 16-bit words replaced by zero
vpblendw        ymm3, ymm0, ymm2, 170
vpblendw        ymm2, ymm1, ymm2, 170
# first variable 32-bit right shift on the blended operands
vpsrlvd ymm2, ymm3, ymm2
# ymm1 = ymm1 >> 16 in every dword (counts for the second shift)
vpsrld  ymm1, ymm1, 16
# second variable 32-bit right shift, applied to the unmasked ymm0
vpsrlvd ymm0, ymm0, ymm1
# merge 16-bit words of both shift results (immediate 170 = 0xaa)
vpblendw        ymm0, ymm2, ymm0, 170




/* Alternative when the counts are constants: v2 holds the powers of two
 * 1<<0 .. 1<<15, so the 16-bit multiply below replaces the variable shift. */
__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(1<<0,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7,1<<8,1<<9,1<<10,1<<11,1<<12,1<<13,1<<14,1<<15);
v1 = _mm256_mullo_epi16(v1, v2);




/*
 * Per-byte variable left shift built on top of _mm256_sllv_epi16.
 *
 * a     : 32 byte elements to shift.
 * count : per-byte shift counts.
 * Returns the vector of (a_i << count_i) truncated to 8 bits.
 */
__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    /* Selects the high byte of every 16-bit element. */
    const __m256i mask = _mm256_set1_epi16((short)0xff00);
    /* Low bytes: shift the whole word by the low count. Bits that spill into
     * the high byte are discarded by the blend. */
    __m256i low_half = _mm256_sllv_epi16(
        a,
        _mm256_andnot_si256(mask, count)
    );
    /* High bytes: clear the low byte first so that zeros are shifted in, and
     * move the high count down to the bottom of the word. */
    __m256i high_half = _mm256_sllv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi16(count, 8)
    );
    /* Take the high bytes from high_half and the low bytes from low_half. */
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16((short)0xff00));
}
/*
 * Per-byte variable logical right shift built on top of _mm256_srlv_epi16.
 *
 * a     : 32 byte elements to shift.
 * count : per-byte shift counts.
 * Returns the vector of (a_i >> count_i), shifting in zeros.
 */
__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    /* Selects the low byte of every 16-bit element. */
    const __m256i mask = _mm256_set1_epi16(0x00ff);
    /* Low bytes: clear the high byte first so none of its bits are shifted
     * down into the result. */
    __m256i low_half = _mm256_srlv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    /* High bytes: shift the whole word by the high count. Bits that land in
     * the low byte are discarded by the blend. */
    __m256i high_half = _mm256_srlv_epi16(
        a,
        _mm256_srli_epi16(count, 8)
    );
    /* Take the high bytes from high_half and the low bytes from low_half. */
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16((short)0xff00));
}

(Peter Cordes在下面提到,在纯AVX512BW(+VL)实现中,_mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00))可以替换为_mm256_mask_blend_epi8(0xaaaaaaaa, low_half, high_half),这可能更快)

