|
| 1 | +#include <cassert> |
| 2 | + |
| 3 | +using uint32 = unsigned int; |
| 4 | +using int64 = long long; |
| 5 | +using uint64 = unsigned long long; |
| 6 | +using uint128 = __uint128_t; |
| 7 | + |
| 8 | +// Montgomery modular multiplication -- about 7x faster |
| 9 | +// ensure mod is an odd number, use after call `set_mod` method |
| 10 | +template<typename word, typename dword, typename sword> |
| 11 | +struct UnsafeMod { |
| 12 | + UnsafeMod(): x(0) {} |
| 13 | + UnsafeMod(word _x): x(init(_x)) {} |
| 14 | + |
| 15 | + UnsafeMod& operator += (const UnsafeMod& rhs) { |
| 16 | + if ((x += rhs.x) >= mod) x -= mod; |
| 17 | + return *this; |
| 18 | + } |
| 19 | + UnsafeMod& operator -= (const UnsafeMod& rhs) { |
| 20 | + if (sword(x -= rhs.x) < 0) x += mod; |
| 21 | + return *this; |
| 22 | + } |
| 23 | + UnsafeMod& operator *= (const UnsafeMod& rhs) { |
| 24 | + x = reduce(dword(x) * rhs.x); |
| 25 | + return *this; |
| 26 | + } |
| 27 | + UnsafeMod operator + (const UnsafeMod &rhs) const { |
| 28 | + return UnsafeMod(*this) += rhs; |
| 29 | + } |
| 30 | + UnsafeMod operator - (const UnsafeMod &rhs) const { |
| 31 | + return UnsafeMod(*this) -= rhs; |
| 32 | + } |
| 33 | + UnsafeMod operator * (const UnsafeMod &rhs) const { |
| 34 | + return UnsafeMod(*this) *= rhs; |
| 35 | + } |
| 36 | + UnsafeMod pow(uint64 e) const { |
| 37 | + UnsafeMod ret(1); |
| 38 | + for (UnsafeMod base = *this; e; e >>= 1, base *= base) { |
| 39 | + if (e & 1) ret *= base; |
| 40 | + } |
| 41 | + return ret; |
| 42 | + } |
| 43 | + word get() const { |
| 44 | + return reduce(x); |
| 45 | + } |
| 46 | + |
| 47 | + static constexpr int word_bits = sizeof(word) * 8; |
| 48 | + static word modulus() { |
| 49 | + return mod; |
| 50 | + } |
| 51 | + static word init(word w) { |
| 52 | + return reduce(dword(w) * r2); |
| 53 | + } |
| 54 | + static void set_mod(word m) { |
| 55 | + mod = m; |
| 56 | + inv = mul_inv(mod); |
| 57 | + r2 = -dword(mod) % mod; |
| 58 | + } |
| 59 | + static word reduce(dword x) { |
| 60 | + word y = word(x >> word_bits) - word((dword(word(x) * inv) * mod) >> word_bits); |
| 61 | + return sword(y) < 0 ? y + mod : y; |
| 62 | + } |
| 63 | + static word mul_inv(word n, int e = 6, word x = 1) { |
| 64 | + return !e ? x : mul_inv(n, e - 1, x * (2 - x * n)); |
| 65 | + } |
| 66 | + static word mod, inv, r2; |
| 67 | + |
| 68 | + word x; |
| 69 | +}; |
| 70 | + |
| 71 | +using UnsafeMod64 = UnsafeMod<uint64, uint128, int64>; |
| 72 | +using UnsafeMod32 = UnsafeMod<uint32, uint64, int>; |
| 73 | +template <> uint64 UnsafeMod64::mod = 0; |
| 74 | +template <> uint64 UnsafeMod64::inv = 0; |
| 75 | +template <> uint64 UnsafeMod64::r2 = 0; |
| 76 | +template <> uint32 UnsafeMod32::mod = 0; |
| 77 | +template <> uint32 UnsafeMod32::inv = 0; |
| 78 | +template <> uint32 UnsafeMod32::r2 = 0; |
| 79 | + |
| 80 | +// in this version mod can be any positive number |
| 81 | +// speed is about 5x than usual mod |
| 82 | +template<typename word, typename dword, typename sword> |
| 83 | +struct Mod { |
| 84 | + Mod() : x(0) {} |
| 85 | + Mod(word _x) : x(init(_x)) {} |
| 86 | + |
| 87 | + Mod& operator += (const Mod& rhs) { |
| 88 | + word hi = (x >> shift) + (rhs.x >> shift) - mod; |
| 89 | + if (sword(hi) < 0) hi += mod; |
| 90 | + x = hi << shift | ((x + rhs.x) & mask); |
| 91 | + return *this; |
| 92 | + } |
| 93 | + Mod& operator -= (const Mod& rhs) { |
| 94 | + word hi = (x >> shift) - (rhs.x >> shift); |
| 95 | + if (sword(hi) < 0) hi += mod; |
| 96 | + x = hi << shift | ((x - rhs.x) & mask); |
| 97 | + return *this; |
| 98 | + } |
| 99 | + Mod& operator *= (const Mod& rhs) { |
| 100 | + x = reduce(x, rhs.x); |
| 101 | + return *this; |
| 102 | + } |
| 103 | + Mod operator + (const Mod& rhs) const { |
| 104 | + return Mod(*this) += rhs; |
| 105 | + } |
| 106 | + Mod operator - (const Mod& rhs) const { |
| 107 | + return Mod(*this) -= rhs; |
| 108 | + } |
| 109 | + Mod operator * (const Mod& rhs) const { |
| 110 | + return Mod(*this) *= rhs; |
| 111 | + } |
| 112 | + word get() const { |
| 113 | + word ret = reduce(x, one); |
| 114 | + word r1 = ret >> shift; |
| 115 | + return mod * (((ret - r1) * inv) & mask) + r1; |
| 116 | + } |
| 117 | + Mod pow(uint64 e) const { |
| 118 | + Mod ret = Mod(1); |
| 119 | + for (Mod base = *this; e; e >>= 1, base *= base) { |
| 120 | + if (e & 1) ret *= base; |
| 121 | + } |
| 122 | + return ret; |
| 123 | + } |
| 124 | + |
| 125 | + static constexpr int word_bits = sizeof(word) * 8; |
| 126 | + static void set_mod(word m) { |
| 127 | + shift = __builtin_ctzll(m); |
| 128 | + mask = (word(1) << shift) - 1; |
| 129 | + mod = m >> shift; |
| 130 | + inv = mul_inv(mod); |
| 131 | + assert(mod * inv == 1); |
| 132 | + r2 = -dword(mod) % mod; |
| 133 | + one = word(1) << shift | 1; |
| 134 | + } |
| 135 | + static word modulus() { |
| 136 | + return mod << shift; |
| 137 | + } |
| 138 | + static word init(word x) { |
| 139 | + return reduce_odd(dword(x) * r2) << shift | (x & mask); |
| 140 | + } |
| 141 | + static word reduce_odd(dword x) { |
| 142 | + word y = word(x >> word_bits) - word((dword(word(x) * inv) * mod) >> word_bits); |
| 143 | + return sword(y) < 0 ? y + mod : y; |
| 144 | + } |
| 145 | + static word reduce(word x0, word x1) { |
| 146 | + word y = reduce_odd(dword(x0 >> shift) * (x1 >> shift)); |
| 147 | + return y << shift | ((x0 * x1) & mask); |
| 148 | + } |
| 149 | + static word mul_inv(word n, int e = 6, word x = 1) { |
| 150 | + return !e ? x : mul_inv(n, e - 1, x * (2 - x * n)); |
| 151 | + } |
| 152 | + static word mod, inv, r2, mask, one; |
| 153 | + static int shift; |
| 154 | + word x; |
| 155 | +}; |
| 156 | + |
| 157 | +using Mod64 = Mod<uint64, uint128, int64>; |
| 158 | +using Mod32 = Mod<uint32, uint64, int>; |
| 159 | +template <> uint64 Mod64::mod = 0; |
| 160 | +template <> uint64 Mod64::inv = 0; |
| 161 | +template <> uint64 Mod64::r2 = 0; |
| 162 | +template <> uint64 Mod64::mask = 0; |
| 163 | +template <> uint64 Mod64::one = 0; |
| 164 | +template <> int Mod64::shift = 0; |
| 165 | +template <> uint32 Mod32::mod = 0; |
| 166 | +template <> uint32 Mod32::inv = 0; |
| 167 | +template <> uint32 Mod32::r2 = 0; |
| 168 | +template <> uint32 Mod32::mask = 0; |
| 169 | +template <> uint32 Mod32::one = 0; |
| 170 | +template <> int Mod32::shift = 0; |
0 commit comments