(a*b)%m分割操作数

(a*b)%m splitting the operands

本文关键字:分割 操作数      更新时间:2023-10-16

我正在执行(a*b)%m,但是所有a,b和m都是10^18阶。

a*b甚至(a%m *b %m)都不可能相乘,因为m也是10^18阶的。

我认为没有必要将它们转换成字符串,然后将它们相乘为字符串,取mod,然后返回为long long int。

我得到了这个问题,那里公认的解决方案是拆分我的操作数。(请参阅链接文章)。然而,我不明白他所说的位移位的解释。

我写了一个函数来计算a和b的乘积取模。

/*
   a=a1+k*a2
   b=b1+k*b2
   (a1+k*a2)*(b1+k*b2) % c = a1*b1 % c + k*a1*b2 % c + k*a2*b1 % c + k*k*a2*b2 % c
*/
ull MAM(ull a,ull b,ull mod)//multiply and mod;ull: unsigned long long
{
   ull a1,a2,b1,b2;
   ull k=4294967296; //2^32
   a1=a%k;
   a2=a/k;
   b1=b%k;
   b2=b/k;
   ull ans = (a1*b1)%mod + (((a1*k)%mod)*b2) %mod + (((k*a2)%mod)*b1)%mod + (((k*k)%mod)*((a2*b2)%mod))%mod;
   return ans;
}

但是如果没有转换,这是行不通的。谁来解释一下答案所说的位移位

With:

(((a1 * k) % mod) * b2) % mod

mod可以大于2 ** 32,因此(((a1 * k) % mod) * b2)可以溢出。

as k == 2 ** 32,所以(k * a1 * b2) % mod((a1 * b2) % mod) * (k % mod) % mod(外部乘法仍然可能溢出,所以在2*2*...*2中拆分k)

(((((a1 * b2) % mod) * 2) % mod) * 2) % mod)...
     ^^^^^^^^^^^^^^
     named x
  • 如果x < 2 ** 63,那么2 * x没有溢出,我们可以进行(2 * x) % mod的迭代。
  • 如果x >= 2 ** 63然后2*x溢出,但我们有2 ** 63 <= mod(我们有x < 2 ** 64),所以(2 * x) % mod可以计算为2 * x - mod或写入没有溢出的x - (mod - x)

代码变成了

const std::uint64 limit = 0x1000000000000000;
std::uint64_t x = (a1 * b2) % mod;
for (int i = 0; i != 32; ++i) {
    if (x < limit) {
        x = (2 * x) % mod; // No overflow
    } else {
        x -= mod - x;      // Manage overflow
    }
}

同样适用于a2 * b1 * ka2 * b2 * k * k

与此等价的位移是a1 = a &0 xffffffffull;A2 = a>>32;或者a1 = (a <<32)>> 32;A2 = a>> 32;

然而,示例代码有一个问题:k*k = 0(溢出),第二和第三项:(((a1 *k) % mod) * b2)和((a2 *k) % mod) * b1)也可以溢出(mod可以是2^64-1)。似乎乘法的实现方式类似于之前发布的答案之一,但在这种情况下,不需要拆分操作数。

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t res = 0;
    uint64_t temp_b;
    if (a >= m)
        a %= m;
    if (b >= m)
        b %= m;
    while (a != 0) {
        if (a & 1) {
            if (b >= m - res) /* Equiv to if (res + b >= m), without overflow */
                res -= m;
            res += b;
        }
        a >>= 1;
        /* Double b, modulo m */
        temp_b = b;
        if (b >= m - b)       /* Equiv to if (2 * b >= m), without overflow */
            temp_b -= m;
        b += temp_b;
    }
    return res;
}

让我解释一下A*B%k -

    Let us assume A = a1a2a3.......an 
    and           B = b1b2b3.......bn
    where ai & bi are numeric digits
    Then 
A*B%k = A*(b1*(pow(2,n-1))%k + A*(b2*(pow(2,n-2))%k + .......... A*(bn*(pow(2,n-n)%k.
OR 
A*B%k = B*(a1*(pow(2,n-1))%k + B*(a2*(pow(2,n-2))%k + .......... B*(an*(pow(2,n-n)%k.

右边的每一项在取模前都不会超过2^63。这样做是安全的,不会有溢出

ah = a >> 1al = a & 1a = (ah << 1) + al = 2*ah + ala*b = 2*ah*b + al*b = ah*b + ah*b + al*b

所以我们可以递归地计算a*b % mod,每个a右移一个,直到a为零:

ull mulMod(ull a, ull b, ull mod) { // assuming a < mod and b < mod
   if (a == 0)
       return 0;
   ull ah = a >> 1;
   ull al = a & 1;
   ull ahb = mulMod(ah, b, mod);
   ull ahb2 = ahb < mod - ahb ? ahb + ahb : ahb - (mod - ahb);
   ull alb = al * b;
   return alb < mod - ahb2 ? ahb2 + alb : ahb2 - (mod - alb);
}

这里我们只需要关心加法模数mod。如果我们注意到(x + y) % modx < mod - y时只是x + y,否则是x - (mod - y),我们可以避免溢出。