如何从AVX寄存器中获取数据

How to get data out of AVX registers?

本文关键字:获取 数据 寄存器 AVX      更新时间:2023-10-16

使用MSVC 2013和AVX 1,我在一个寄存器中有8个浮点:

__m256 foo = mm256_fmadd_ps(a,b,c);

现在,我想为所有8个浮点调用inline void print(float) {...}。看起来英特尔AVX的本质会让这件事变得相当复杂:

print(_castu32_f32(_mm256_extract_epi32(foo, 0)));
print(_castu32_f32(_mm256_extract_epi32(foo, 1)));
print(_castu32_f32(_mm256_extract_epi32(foo, 2)));
// ...

但是MSVC甚至没有这两个本质。当然,我可以将值写回内存并从那里加载,但我怀疑在程序集级别不需要溢出寄存器。

奖金Q:我当然想写

for(int i = 0; i !=8; ++i) 
print(_castu32_f32(_mm256_extract_epi32(foo, i)))

但是MSVC不理解许多内部函数需要循环展开。如何在__m256 foo中的8x32浮点上编写循环?

假设你只有AVX(即没有AVX2),那么你可以做这样的事情:

float extract_float(const __m128 v, const int i)
{
float x;
_MM_EXTRACT_FLOAT(x, v, i);
return x;
}
void print(const __m128 v)
{
print(extract_float(v, 0));
print(extract_float(v, 1));
print(extract_float(v, 2));
print(extract_float(v, 3));
}
void print(const __m256 v)
{
print(_mm256_extractf128_ps(v, 0));
print(_mm256_extractf128_ps(v, 1));
}

然而,我想我可能只会使用一个联盟:

union U256f {
__m256 v;
float a[8];
};
void print(const __m256 v)
{
const U256f u = { v };
for (int i = 0; i < 8; ++i)
print(u.a[i]);
}

小心:_mm256_fmadd_ps不是AVX1的一部分。FMA3有自己的功能位,并且是在英特尔和Haswell一起推出的。AMD推出了带有Piledriver的FMA3(AVX1+FMA4+FMA3,无AVX2)。


在asm级别,如果您想将八个32位元素放入整数寄存器,那么实际上存储到堆栈然后进行标量加载会更快。pextrd是一个关于SnB系列和推土机系列的2-uop指令。(以及不支持AVX的Nehalem和Silvermont)。

唯一一个vextractf128+2xmovd+6xpextrd不可怕的CPU是AMD Jaguar。(廉价的pextrd,只有一个加载端口。)(参见Agner Fog的insn表)

宽对齐的存储可以转发到重叠的窄负载。(当然,您可以使用movd来获得低元素,因此您可以混合使用加载端口和ALU端口uops)。


当然,您似乎是通过使用整数提取来提取floats,然后将其转换回浮点值那太可怕了。

您实际需要的是每个float在其自己的xmm寄存器的低元素中。vextractf128显然是开始的方式,将元素4放在新xmm reg的底部。然后6xAVXshufps可以很容易地获得每一半的其他三个元素。(或者movshdupmovhlps具有较短的编码:没有立即字节)。

与1个存储和7个加载uop相比,7个shuffle uop值得考虑,但如果您无论如何都要为函数调用溢出向量,则不值得考虑。


ABI注意事项:

在Windows中,xmm6-15保留调用(只有low128;ymm6-15的上半部分保留调用)。这是从vextractf128开始的另一个原因。

在SysV ABI中,所有xmm/ymm/zmm寄存器都被调用阻塞,因此每个print()函数都需要一个溢出/重新加载。在那里唯一明智的做法是存储到内存中,并用原始向量调用print(即打印低元素,因为它将忽略寄存器的其余部分)。然后movss xmm0, [rsp+4]并调用第二个元素上的print

将所有8个浮点都很好地解压到8个矢量regs中是没有好处的,因为在第一个函数调用之前,它们都必须单独溢出!

float valueAVX(__m256 a, int i){
float ret = 0;
switch (i){
case 0:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)      ( a3, a2, a1, a0 )
// cvtss_f32             a0 
ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 0));
break;
case 1: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)     lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 1)      ( - , a3, a2, a1 )
// cvtss_f32                 a1 
__m128 lo = _mm256_extractf128_ps(a, 0);
ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 1));
}
break;
case 2: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// movehl(lo, lo)        ( - , - , a3, a2 )
// cvtss_f32               a2 
__m128 lo = _mm256_extractf128_ps(a, 0);
ret = _mm_cvtss_f32(_mm_movehl_ps(lo, lo));
}
break;
case 3: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 3)    ( - , - , - , a3 )
// cvtss_f32               a3 
__m128 lo = _mm256_extractf128_ps(a, 0);                    
ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 3));
}
break;
case 4:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)      ( a7, a6, a5, a4 )
// cvtss_f32             a4 
ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 1));
break;
case 5: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)     hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 1)      ( - , a7, a6, a5 )
// cvtss_f32                 a5 
__m128 hi = _mm256_extractf128_ps(a, 1);
ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 1));
}
break;
case 6: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// movehl(hi, hi)        ( - , - , a7, a6 )
// cvtss_f32               a6 
__m128 hi = _mm256_extractf128_ps(a, 1);
ret = _mm_cvtss_f32(_mm_movehl_ps(hi, hi));
}
break;
case 7: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 3)    ( - , - , - , a7 )
// cvtss_f32               a7 
__m128 hi = _mm256_extractf128_ps(a, 1);
ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 3));
}
break;
}
return ret;
}

有关asm的详细信息,请参阅我的另一个答案。这个答案是关于C++方面的。


void foo(__m256 v) {
alignas(32) float vecbuf[8];   // 32-byte aligned array allows aligned store
// avoiding the risk of cache-line splits
_mm256_store_ps(vecbuf, v);
float v0 = _mm_cvtss_f32(_mm256_castps256_ps128(v));  // the bottom of the register
float v1 = vecbuf[1];
float v2 = vecbuf[2];
...
// or loop over vecbuf[i]
// if you do need all 8 elements one at a time, this is a good way
}

或者在CCD_ 21上循环。矢量存储可以转发到其元素之一的标量重载,因此这只引入了大约6个延迟周期,并且可以同时进行多个重载。(因此,它对具有2/时钟负载吞吐量的现代CPU的吞吐量非常有利。)

请注意,我避免了重新加载低元素;寄存器中已经的向量的低元素是标量CCD_ 22。_mm_cvtss_f32( _mm256_castps256_ps128(v) )就是让编译器的类型系统满意的简单方法;它编译为零asm指令,因此它实际上是免费的(排除遗漏的优化错误)。(请参阅英特尔的内部指南)。XMM寄存器是相应YMM寄存器的低位128,标量浮点/双精度是XMM寄存器的低32或64位。(上半部分的垃圾无关紧要。)

铸造第一个曾经给OoO高管一些事情做,同时等待其他人的到来。您可以考虑混洗以获得vunpckhpsvmovhlps处于低位128的第二个元素,因此您可以快速准备好2个元素,如果这有助于填补延迟气泡的话。

在GNU C/C++中,您可以用v[1]甚至像v[i]这样的变量索引来索引像数组这样的向量类型。编译器将在shuffle或store/reload之间进行选择。

但这对MSVC来说是不可移植的,MSVC根据与一些命名成员的联合来定义__m256

存储到数组并重新加载是可移植的,编译器有时甚至可以将其优化为shuffle(如果你不想这样,请检查生成的asm。)

例如clang优化了仅将CCD_ 29返回为简单vshufps的函数。https://godbolt.org/z/tHJH_V


如果您真的想把向量的所有元素加成标量总和,请shuffle和SIMD相加。在x86 上进行水平浮点矢量求和的最快方法

(对于单个向量的元素的乘法、最小、最大或其他关联归约也是如此。当然,如果你有多个向量,可以垂直运算到一个向量,比如_mm256_add_ps(v1,v2))


使用Agner Fog的矢量类库,他的包装器类重载operator[],使其以您所期望的方式工作,即使对于非常量args也是如此。这通常编译为存储/重载,但它使用C++编写代码变得容易。启用优化后,您可能会得到不错的结果。(除了低元素可能会被存储/重新加载,而不仅仅是在适当的地方使用。所以你可能想把vec[0]特殊化为_mm_cvtss_f32(vec)或其他什么。)

(VCL过去是根据GPL许可的,但现在的版本是一个简单的Apache许可。)

另请参阅我的github repo,其中对Agner的VCL进行了大部分未经测试的更改,以为某些函数生成更好的代码。


有一个_MM_EXTRACT_FLOAT包装器宏,但它很奇怪,只使用SSE4.1定义。我认为这是为了配合SSE4.1extractps(它可以将浮点的二进制表示提取到整数寄存器中,或者存储到内存中)。不过,当目标是float时,它gcc确实将其编译为FP shuffle。如果您希望结果为float,请注意其他编译器不会将其编译为实际的extractps指令,因为extractps不是这样做的。(这就是insertps的作用,但更简单的FP混洗会占用更少的指令字节。例如,带有AVX的shufps非常棒。)

这很奇怪,因为它需要3个参数:_MM_EXTRACT_FLOAT(dest, src_m128, idx),所以您甚至不能将它用作float局部的初始值设定项。


在矢量上循环

gcc将为您展开类似的循环,但只能使用-O1或更高版本。在-O0,它会给您一条错误消息。

float bad_hsum(__m128 & fv) {
float sum = 0;
for (int i=0 ; i<4 ; i++) {
float f;
_MM_EXTRACT_FLOAT(f, fv, i);  // works only with -O1 or higher
sum += f;
}
return sum;
}

在visual studio上。。我试过下面的:

__m256 _zd = { 17.236,19.336,72.35,47.391,8.354,9.336 };        Single precision --- floats 32 bits (1 signed 8 exponent 23 mantissa)
__asm nop;
float(*ArrPtr)[8] = (float(*)[8])&_zd;
std::cout << *(*ArrPtr) << " Extracted values " << *((*ArrPtr)+1) << std::end

l;