Skip to content

Commit c410784

Browse files
committed
add fast miller-robin and prime factorization
1 parent 0acd4cd commit c410784

File tree

3 files changed

+227
-58
lines changed

3 files changed

+227
-58
lines changed

mathematics/Fibonacci.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
#include "basic.hpp"
2+
#include <functional>
23

3-
void fib(int64 n, int64 &x, int64 &y) {// store in x, n-th
4-
if (n == 1) {
5-
x = y = 1;
6-
return;
7-
} else if (n & 1) {
8-
fib(n - 1, y, x);
9-
y += x;
4+
// f(0) = 0, f(1) = 1, f(n) = f(n - 1) + f(n - 2)
5+
std::pair<int64, int64> fib(int64 n, int64 mod) {
6+
if (n == 0) return {0, 1};
7+
int64 x, y;
8+
if (n & 1) {
9+
std::tie(y, x) = fib(n - 1, mod);
10+
return {x, (y + x) % mod};
1011
} else {
11-
int64 a, b;
12-
fib(n >> 1, a, b);
13-
y = a * a + b * b;
14-
x = a * b + a * (b - a);
12+
std::tie(x, y) = fib(n >> 1, mod);
13+
return {(x * y + x * (y - x + mod)) % mod, (x * x + y * y) % mod};
1514
}
1615
}

mathematics/mod_int.hpp renamed to mathematics/ModInt.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ struct UnsafeMod {
1212
UnsafeMod(): x(0) {}
1313
UnsafeMod(word _x): x(init(_x)) {}
1414

15+
bool operator == (const UnsafeMod& rhs) const {
16+
return x == rhs.x;
17+
}
18+
bool operator != (const UnsafeMod& rhs) const {
19+
return x != rhs.x;
20+
}
1521
UnsafeMod& operator += (const UnsafeMod& rhs) {
1622
if ((x += rhs.x) >= mod) x -= mod;
1723
return *this;

mathematics/prime.cc

Lines changed: 211 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,220 @@
1-
#include "basic.hpp"
1+
#include <cmath>
2+
#include <numeric>
3+
#include <cassert>
4+
#include <cstdio>
25
#include <vector>
36
#include <cstdlib>
47
#include <algorithm>
58

6-
struct Primality {
7-
public:
8-
// 用miller rabin素数测试判断n是否为质数
9-
bool is_prime(int64 n) {
10-
if (n <= 1) return false;
11-
if (n <= 3) return true;
12-
if (~n & 1) return false;
13-
const int u[] = {2,3,5,7,325,9375,28178,450775,9780504,1795265022,0};
14-
int64 e = n - 1, a, c = 0; // 原理:http://miller-rabin.appspot.com/
15-
while (~e & 1) e >>= 1, ++c;
16-
for (int i = 0; u[i]; ++i) {
17-
if (n <= u[i]) return true;
18-
a = pow_mod(u[i], e, n);
19-
if (a == 1) continue;
20-
for (int j = 1; a != n - 1; ++j) {
21-
if (j == c) return false;
22-
a = mul_mod(a, a, n);
9+
namespace prime {
10+
11+
using uint128 = __uint128_t;
12+
using uint64 = unsigned long long;
13+
using int64 = long long;
14+
using uint32 = unsigned int;
15+
using pii = std::pair<uint64, uint32>;
16+
17+
inline uint64 sqr(uint64 x) { return x * x; }
18+
inline uint32 isqrt(uint64 x) { return sqrtl(x); }
19+
inline uint32 ctz(uint64 x) { return __builtin_ctzll(x); }
20+
21+
template <typename word>
22+
word gcd(word a, word b) {
23+
while (b) { word t = a % b; a = b; b = t; }
24+
return a;
25+
}
26+
27+
template <typename word, typename dword, typename sword>
28+
struct Mod {
29+
Mod(): x(0) {}
30+
Mod(word _x): x(init(_x)) {}
31+
bool operator == (const Mod& rhs) const { return x == rhs.x; }
32+
bool operator != (const Mod& rhs) const { return x != rhs.x; }
33+
Mod& operator += (const Mod& rhs) { if ((x += rhs.x) >= mod) x -= mod; return *this; }
34+
Mod& operator -= (const Mod& rhs) { if (sword(x -= rhs.x) < 0) x += mod; return *this; }
35+
Mod& operator *= (const Mod& rhs) { x = reduce(dword(x) * rhs.x); return *this; }
36+
Mod operator + (const Mod &rhs) const { return Mod(*this) += rhs; }
37+
Mod operator - (const Mod &rhs) const { return Mod(*this) -= rhs; }
38+
Mod operator * (const Mod &rhs) const { return Mod(*this) *= rhs; }
39+
Mod operator - () const { return Mod() - *this; }
40+
Mod pow(uint64 e) const {
41+
Mod ret(1);
42+
for (Mod base = *this; e; e >>= 1, base *= base) {
43+
if (e & 1) ret *= base;
44+
}
45+
return ret;
46+
}
47+
word get() const { return reduce(x); }
48+
static constexpr int word_bits = sizeof(word) * 8;
49+
static word modulus() { return mod; }
50+
static word init(word w) { return reduce(dword(w) * r2); }
51+
static void set_mod(word m) { mod = m, inv = mul_inv(mod), r2 = -dword(mod) % mod; }
52+
static word reduce(dword x) {
53+
word y = word(x >> word_bits) - word((dword(word(x) * inv) * mod) >> word_bits);
54+
return sword(y) < 0 ? y + mod : y;
55+
}
56+
static word mul_inv(word n, int e = 6, word x = 1) {
57+
return !e ? x : mul_inv(n, e - 1, x * (2 - x * n));
58+
}
59+
static word mod, inv, r2;
60+
61+
word x;
62+
};
63+
64+
using Mod64 = Mod<uint64, uint128, int64>;
65+
using Mod32 = Mod<uint32, uint64, int>;
66+
template <> uint64 Mod64::mod = 0;
67+
template <> uint64 Mod64::inv = 0;
68+
template <> uint64 Mod64::r2 = 0;
69+
template <> uint32 Mod32::mod = 0;
70+
template <> uint32 Mod32::inv = 0;
71+
template <> uint32 Mod32::r2 = 0;
72+
73+
template <class word, class mod>
74+
bool composite(word n, const uint32* bases, int m) {
75+
mod::set_mod(n);
76+
int s = __builtin_ctzll(n - 1);
77+
word d = (n - 1) >> s;
78+
mod one{1}, minus_one{n - 1};
79+
for (int i = 0, j; i < m; ++i) {
80+
mod a = mod(bases[i]).pow(d);
81+
if (a == one || a == minus_one) continue;
82+
for (j = s - 1; j > 0; --j) {
83+
if ((a *= a) == minus_one) break;
84+
}
85+
if (j == 0) return true;
86+
}
87+
return false;
88+
}
89+
90+
bool is_prime(uint64 n) { // reference: http://miller-rabin.appspot.com
91+
assert(n < (uint64(1) << 63));
92+
static const uint32 bases[][7] = {
93+
{2, 3},
94+
{2, 299417},
95+
{2, 7, 61},
96+
{15, 176006322, uint32(4221622697)},
97+
{2, 2570940, 211991001, uint32(3749873356)},
98+
{2, 2570940, 880937, 610386380, uint32(4130785767)},
99+
{2, 325, 9375, 28178, 450775, 9780504, 1795265022}
100+
};
101+
if (n <= 1) return false;
102+
if (!(n & 1)) return n == 2;
103+
if (n <= 8) return true;
104+
int x = 6, y = 7;
105+
if (n < 1373653) x = 0, y = 2;
106+
else if (n < 19471033) x = 1, y = 2;
107+
else if (n < 4759123141) x = 2, y = 3;
108+
else if (n < 154639673381) x = y = 3;
109+
else if (n < 47636622961201) x = y = 4;
110+
else if (n < 3770579582154547) x = y = 5;
111+
if (n < (uint32(1) << 31)) {
112+
return !composite<uint32, Mod32>(n, bases[x], y);
113+
} else if (n < (uint64(1) << 63)) {
114+
return !composite<uint64, Mod64>(n, bases[x], y);
115+
}
116+
return true;
117+
}
118+
119+
struct ExactDiv {
120+
ExactDiv() {}
121+
ExactDiv(uint64 n) : n(n), i(Mod64::mul_inv(n)), t(uint64(-1) / n) {}
122+
friend uint64 operator / (uint64 n, ExactDiv d) { return n * d.i; };
123+
bool divide(uint64 n) { return n / *this <= t; }
124+
uint64 n, i, t;
125+
};
126+
127+
std::vector<ExactDiv> primes;
128+
129+
void init(uint32 n) {
130+
uint32 sqrt_n = sqrt(n);
131+
std::vector<bool> is_prime(n + 1, 1);
132+
primes.clear();
133+
for (uint32 i = 2; i <= sqrt_n; ++i) if (is_prime[i]) {
134+
if (i != 2) primes.push_back(ExactDiv(i));
135+
for (uint32 j = i * i; j <= n; j += i) is_prime[j] = 0;
136+
}
137+
}
138+
139+
template <typename word, typename mod>
140+
word brent(word n, word c) { // n must be composite and odd.
141+
const uint64 s = 256;
142+
mod::set_mod(n);
143+
const mod one = mod(1), mc = mod(c);
144+
mod y = one;
145+
for (uint64 l = 1; ; l <<= 1) {
146+
auto x = y;
147+
for (int i = 0; i < (int)l; ++i) y = y * y + mc;
148+
mod p = one;
149+
for (int k = 0; k < (int)l; k += s) {
150+
auto sy = y;
151+
for (int i = 0; i < (int)std::min(s, l - k); ++i) {
152+
y = y * y + mc;
153+
p *= y - x;
154+
}
155+
word g = gcd(n, p.x);
156+
if (g == 1) continue;
157+
if (g == n) for (g = 1, y = sy; g == 1; ) {
158+
y = y * y + mc, g = gcd(n, (y - x).x);
23159
}
160+
return g;
24161
}
25-
return true;
26-
}
27-
// 求一个小于n的因数,期望复杂度为O(n^0.25),当n为非合数时返回n本身
28-
int64 pollard_rho(int64 n){
29-
if (n <= 3 || is_prime(n)) return n; // 保证n为合数时可去掉这行
30-
while (1) {
31-
int i = 1, cnt = 2;
32-
int64 x = rand() % n, y = x, c = rand() % n;
33-
if (!c || c == n - 2) ++c;
34-
do {
35-
int64 u = gcd(n - x + y, n);
36-
if (u > 1 && u < n) return u;
37-
if (++i == cnt) y = x, cnt <<= 1;
38-
x = (c + mul_mod(x, x, n)) % n;
39-
} while (x != y);
162+
}
163+
}
164+
165+
uint64 brent(uint64 n, uint64 c) {
166+
if (n < (uint32(1) << 31)) {
167+
return brent<uint32, Mod32>(n, c);
168+
} else if (n < (uint64(1) << 63)) {
169+
return brent<uint64, Mod64>(n, c);
170+
}
171+
return 0;
172+
}
173+
174+
std::vector<pii> factors(uint64 n) {
175+
assert(n < (uint64(1) << 63));
176+
if (n <= 1) return {};
177+
std::vector<pii> ret;
178+
uint32 v = sqrtl(n);
179+
if (uint64(v) * v == n) {
180+
ret = factors(v);
181+
for (auto &&e: ret) e.second *= 2;
182+
return ret;
183+
}
184+
v = cbrtl(n);
185+
if (uint64(v) * v * v == n) {
186+
ret = factors(v);
187+
for (auto &&e: ret) e.second *= 3;
188+
return ret;
189+
}
190+
if (!(n & 1)) {
191+
uint32 e = __builtin_ctzll(n);
192+
ret.emplace_back(2, e);
193+
n >>= e;
194+
}
195+
uint64 lim = sqr(primes.back().n);
196+
for (auto &&p: primes) {
197+
if (sqr(p.n) > n) break;
198+
if (p.divide(n)) {
199+
uint32 e = 1; n = n / p;
200+
while (p.divide(n)) n = n / p, e++;
201+
ret.emplace_back(p.n, e);
40202
}
41-
return n;
42-
}
43-
// 使用rho方法对n做质因数分解,建议先筛去小质因数后再用此函数
44-
std::vector<int64> factorize(int64 n){
45-
std::vector<int64> u;
46-
if (n > 1) u.push_back(n);
47-
for (size_t i = 0; i < u.size(); ++i){
48-
int64 x = pollard_rho(u[i]);
49-
if(x == u[i]) continue;
50-
u[i--] /= x;
51-
u.push_back(x);
203+
}
204+
205+
uint32 s = ret.size();
206+
while (n > lim && !is_prime(n)) {
207+
for (uint64 c = 1; ; ++c) {
208+
uint64 p = brent(n, c);
209+
if (!is_prime(p)) continue;
210+
uint32 e = 1; n /= p;
211+
while (n % p == 0) n /= p, e += 1;
212+
ret.emplace_back(p, e);
213+
break;
52214
}
53-
std::sort(u.begin(), u.end());
54-
return u;
55215
}
56-
};
216+
if (n > 1) ret.emplace_back(n, 1);
217+
if (ret.size() - s >= 2) sort(ret.begin() + s, ret.end());
218+
return ret;
219+
}
220+
}

0 commit comments

Comments
 (0)