在 x86-64 平台上为 C(++) 中的 64 位无符号参数计算 (a*b)%n FAST

Compute (a*b)%n FAST for 64-bit unsigned arguments in C(++) on x86-64 platforms?

本文关键字:计算 参数 无符号 FAST 中的 平台 x86-64      更新时间:2023-10-16

我正在寻找一种快速方法来有效地计算(ab(模n(在数学意义上(abn类型为uint64_t。我可以忍受诸如n!=0甚至a<n && b<n等先决条件。

请注意,C 表达式(a*b)%n不会剪切它,因为乘积被截断为 64 位。我正在寻找(uint64_t)(((uint128_t)a*b)%n)除了我没有uint128_t(我知道,在视觉C++中(。

我想要一个视觉C++(最好(或GCC/clang内在的,充分利用x86-64平台上可用的底层硬件;或者如果便携式inline功能无法做到这一点。

好的,这个怎么样(未测试(

modmul:
; rcx = a
; rdx = b
; r8 = n
mov rax, rdx
mul rcx
div r8
mov rax, rdx
ret

前提是a * b / n <= ~0ULL,否则会有除法错误。这是一个比a < n && m < n稍微不那么严格的条件,其中一个可以大于n,只要另一个足够小。

不幸的是,它必须单独组装和链接,因为 MSVC 不支持 64 位目标的内联 asm。

它仍然很慢,真正的问题是 64位div ,这可能需要近一百个周期(说真的,例如在 Nehalem 上最多 90 个周期(。

你可以用老式的方式用shift/add/subtract来做。 下面的代码假设a <n和>
n <263(所以事情不会溢出(:

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1)
            if ((rv += a) >= n) rv -= n;
        if ((a += a) >= n) a -= n;
        b >>= 1; }
    return rv;
}

您可以使用while (a && b)进行循环,如果a可能是n因素,则可以短路。 如果a不是n的因素,则会稍慢(更多比较和可能正确预测的分支(。

如果你真的,绝对,需要最后一位(允许n最多264-1(,你可以使用:

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1) {
            rv += a;
            if (rv < a || rv >= n) rv -= n; }
        uint64_t t = a;
        a += a;
        if (a < t || a >= n) a -= n;
        b >>= 1; }
    return rv;
}

或者,只需使用 GCC intriinsics 访问底层 x64 指令:

inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv;
    asm ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b));
    asm ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n));
    return rv;
}

然而,64 位div 指令真的很慢,所以循环实际上可能更快。 您需要分析才能确定。

7年后,我得到了一个在Visual Studio 2019中工作的解决方案

#include <stdint.h>
#include <intrin.h>
#pragma intrinsic(_umul128)
#pragma intrinsic(_udiv128)
// compute (a*b)%n with 128-bit intermediary result
// assumes n>0  and  a*b < n * 2**64 (always the case when a<=n || b<=n )
inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}
// compute (a*b)%n with 128-bit intermediary result
// assumes n>0, works including if a*b >= n * 2**64
inline uint64_t mulmod1(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a % n, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}

这个内在函数被命名为__mul128

typedef unsigned long long BIG;
// handles only the "hard" case when high bit of n is set
BIG shl_mod( BIG v, BIG n, int by )
{
    if (v > n) v -= n;
    while (by--) {
        if (v > (n-v))
            v -= n-v;
        else
            v <<= 1;
    }
    return v;
}

现在您可以使用shl_mod(B, n, 64)

没有内联程序集有点糟糕。无论如何,函数调用开销实际上非常小。参数在易失寄存器中传递,无需清理。

我没有汇编程序,x64 目标不支持__asm,所以我别无选择,只能自己从操作码"组装"我的函数。

显然这取决于.我使用 mpir (gmp( 作为参考来显示函数产生正确的结果。


#include "stdafx.h"
// mulmod64(a, b, m) == (a * b) % m
typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);
uint8_t mulmod64_opcodes[] = {
    0x48, 0x89, 0xC8, // mov rax, rcx
    0x48, 0xF7, 0xE2, // mul rdx
    0x4C, 0x89, 0xC1, // mov rcx, r8
    0x48, 0xF7, 0xF1, // div rcx
    0x48, 0x89, 0xD0, // mov rax,rdx
    0xC3              // ret
};
mulmod64_fnptr_t mulmod64_fnptr;
void init() {
    DWORD dwOldProtect;
    VirtualProtect(
        &mulmod64_opcodes,
        sizeof(mulmod64_opcodes),
        PAGE_EXECUTE_READWRITE,
        &dwOldProtect);
    // NOTE: reinterpret byte array as a function pointer
    mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
}
int main() {
    init();
    uint64_t a64 = 2139018971924123ull;
    uint64_t b64 = 1239485798578921ull;
    uint64_t m64 = 8975489368910167ull;
    // reference code
    mpz_t a, b, c, m, r;
    mpz_inits(a, b, c, m, r, NULL);
    mpz_set_ui(a, a64);
    mpz_set_ui(b, b64);
    mpz_set_ui(m, m64);
    mpz_mul(c, a, b);
    mpz_mod(r, c, m);
    gmp_printf("(%Zd * %Zd) mod %Zd = %Zdn", a, b, m, r);
    // using mulmod64
    uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
    printf("(%llu * %llu) mod %llu = %llun", a64, b64, m64, r64);
    return 0;
}