Skip to content

Commit 908c897

Browse files
committed
add fast mod int
1 parent 51d094c commit 908c897

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

mathematics/mod_int.hpp

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#include <cassert>
2+
3+
using uint32 = unsigned int;
4+
using int64 = long long;
5+
using uint64 = unsigned long long;
6+
using uint128 = __uint128_t;
7+
8+
// Montgomery modular multiplication -- about 7x faster
9+
// ensure mod is an odd number, use after call `set_mod` method
10+
template<typename word, typename dword, typename sword>
11+
struct UnsafeMod {
12+
UnsafeMod(): x(0) {}
13+
UnsafeMod(word _x): x(init(_x)) {}
14+
15+
UnsafeMod& operator += (const UnsafeMod& rhs) {
16+
if ((x += rhs.x) >= mod) x -= mod;
17+
return *this;
18+
}
19+
UnsafeMod& operator -= (const UnsafeMod& rhs) {
20+
if (sword(x -= rhs.x) < 0) x += mod;
21+
return *this;
22+
}
23+
UnsafeMod& operator *= (const UnsafeMod& rhs) {
24+
x = reduce(dword(x) * rhs.x);
25+
return *this;
26+
}
27+
UnsafeMod operator + (const UnsafeMod &rhs) const {
28+
return UnsafeMod(*this) += rhs;
29+
}
30+
UnsafeMod operator - (const UnsafeMod &rhs) const {
31+
return UnsafeMod(*this) -= rhs;
32+
}
33+
UnsafeMod operator * (const UnsafeMod &rhs) const {
34+
return UnsafeMod(*this) *= rhs;
35+
}
36+
UnsafeMod pow(uint64 e) const {
37+
UnsafeMod ret(1);
38+
for (UnsafeMod base = *this; e; e >>= 1, base *= base) {
39+
if (e & 1) ret *= base;
40+
}
41+
return ret;
42+
}
43+
word get() const {
44+
return reduce(x);
45+
}
46+
47+
static constexpr int word_bits = sizeof(word) * 8;
48+
static word modulus() {
49+
return mod;
50+
}
51+
static word init(word w) {
52+
return reduce(dword(w) * r2);
53+
}
54+
static void set_mod(word m) {
55+
mod = m;
56+
inv = mul_inv(mod);
57+
r2 = -dword(mod) % mod;
58+
}
59+
static word reduce(dword x) {
60+
word y = word(x >> word_bits) - word((dword(word(x) * inv) * mod) >> word_bits);
61+
return sword(y) < 0 ? y + mod : y;
62+
}
63+
static word mul_inv(word n, int e = 6, word x = 1) {
64+
return !e ? x : mul_inv(n, e - 1, x * (2 - x * n));
65+
}
66+
static word mod, inv, r2;
67+
68+
word x;
69+
};
70+
71+
using UnsafeMod64 = UnsafeMod<uint64, uint128, int64>;
72+
using UnsafeMod32 = UnsafeMod<uint32, uint64, int>;
73+
template <> uint64 UnsafeMod64::mod = 0;
74+
template <> uint64 UnsafeMod64::inv = 0;
75+
template <> uint64 UnsafeMod64::r2 = 0;
76+
template <> uint32 UnsafeMod32::mod = 0;
77+
template <> uint32 UnsafeMod32::inv = 0;
78+
template <> uint32 UnsafeMod32::r2 = 0;
79+
80+
// in this version mod can be any positive number
81+
// speed is about 5x than usual mod
82+
template<typename word, typename dword, typename sword>
83+
struct Mod {
84+
Mod() : x(0) {}
85+
Mod(word _x) : x(init(_x)) {}
86+
87+
Mod& operator += (const Mod& rhs) {
88+
word hi = (x >> shift) + (rhs.x >> shift) - mod;
89+
if (sword(hi) < 0) hi += mod;
90+
x = hi << shift | ((x + rhs.x) & mask);
91+
return *this;
92+
}
93+
Mod& operator -= (const Mod& rhs) {
94+
word hi = (x >> shift) - (rhs.x >> shift);
95+
if (sword(hi) < 0) hi += mod;
96+
x = hi << shift | ((x - rhs.x) & mask);
97+
return *this;
98+
}
99+
Mod& operator *= (const Mod& rhs) {
100+
x = reduce(x, rhs.x);
101+
return *this;
102+
}
103+
Mod operator + (const Mod& rhs) const {
104+
return Mod(*this) += rhs;
105+
}
106+
Mod operator - (const Mod& rhs) const {
107+
return Mod(*this) -= rhs;
108+
}
109+
Mod operator * (const Mod& rhs) const {
110+
return Mod(*this) *= rhs;
111+
}
112+
word get() const {
113+
word ret = reduce(x, one);
114+
word r1 = ret >> shift;
115+
return mod * (((ret - r1) * inv) & mask) + r1;
116+
}
117+
Mod pow(uint64 e) const {
118+
Mod ret = Mod(1);
119+
for (Mod base = *this; e; e >>= 1, base *= base) {
120+
if (e & 1) ret *= base;
121+
}
122+
return ret;
123+
}
124+
125+
static constexpr int word_bits = sizeof(word) * 8;
126+
static void set_mod(word m) {
127+
shift = __builtin_ctzll(m);
128+
mask = (word(1) << shift) - 1;
129+
mod = m >> shift;
130+
inv = mul_inv(mod);
131+
assert(mod * inv == 1);
132+
r2 = -dword(mod) % mod;
133+
one = word(1) << shift | 1;
134+
}
135+
static word modulus() {
136+
return mod << shift;
137+
}
138+
static word init(word x) {
139+
return reduce_odd(dword(x) * r2) << shift | (x & mask);
140+
}
141+
static word reduce_odd(dword x) {
142+
word y = word(x >> word_bits) - word((dword(word(x) * inv) * mod) >> word_bits);
143+
return sword(y) < 0 ? y + mod : y;
144+
}
145+
static word reduce(word x0, word x1) {
146+
word y = reduce_odd(dword(x0 >> shift) * (x1 >> shift));
147+
return y << shift | ((x0 * x1) & mask);
148+
}
149+
static word mul_inv(word n, int e = 6, word x = 1) {
150+
return !e ? x : mul_inv(n, e - 1, x * (2 - x * n));
151+
}
152+
static word mod, inv, r2, mask, one;
153+
static int shift;
154+
word x;
155+
};
156+
157+
using Mod64 = Mod<uint64, uint128, int64>;
158+
using Mod32 = Mod<uint32, uint64, int>;
159+
template <> uint64 Mod64::mod = 0;
160+
template <> uint64 Mod64::inv = 0;
161+
template <> uint64 Mod64::r2 = 0;
162+
template <> uint64 Mod64::mask = 0;
163+
template <> uint64 Mod64::one = 0;
164+
template <> int Mod64::shift = 0;
165+
template <> uint32 Mod32::mod = 0;
166+
template <> uint32 Mod32::inv = 0;
167+
template <> uint32 Mod32::r2 = 0;
168+
template <> uint32 Mod32::mask = 0;
169+
template <> uint32 Mod32::one = 0;
170+
template <> int Mod32::shift = 0;

0 commit comments

Comments
 (0)