Skip to content

Commit e3335c0

Browse files
committed
add factorial mod p^e
1 parent 13a04bd commit e3335c0

File tree

2 files changed

+180
-14
lines changed

2 files changed

+180
-14
lines changed

mathematics/basic.hpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,33 @@
11
#pragma once
22
#include <cassert>
3+
#include <algorithm>
34

45
using int64 = long long;
6+
using uint64 = unsigned long long;
7+
using int128 = __int128_t;
8+
using uint128 = __uint128_t;
9+
using float80 = long double;
510

6-
// mod should be not greater than 4e18
11+
// return a % b
12+
inline uint64 mod128_64_small(uint128 a, uint64 b) {
13+
uint64 q, r;
14+
__asm__ (
15+
"divq\t%4"
16+
: "=a"(q), "=d"(r)
17+
: "0"(uint64(a)), "1"(uint64(a >> 64)), "rm"(b)
18+
);
19+
return r;
20+
}
21+
22+
// mod should be not greater than 2^62 (about 4e18)
723
// return a * b % mod when mod is less than 2^31
8-
inline int64 mul_mod(int64 a, int64 b, int64 mod) {
24+
inline uint64 mul_mod(uint64 a, uint64 b, uint64 mod) {
925
assert(0 <= a && a < mod);
1026
assert(0 <= b && b < mod);
1127
if (mod < int(1e9)) return a * b % mod;
12-
int64 k = (int64)((long double)a * b / mod);
13-
int64 res = a * b - k * mod;
14-
res %= mod;
15-
if (res < 0) res += mod;
28+
uint64 k = (uint64)((long double)a * b / mod);
29+
uint64 res = a * b - k * mod;
30+
if ((int64)res < 0) res += mod;
1631
return res;
1732
}
1833

@@ -24,6 +39,10 @@ inline int64 sub_mod(int64 x, int64 y, int64 mod) {
2439
return (x - y + mod) % mod;
2540
}
2641

42+
// Returns (a * b + c) % mod.  The sum is formed in 128 bits, so it cannot
// overflow for any 64-bit inputs, and is then reduced by mod128_64_small.
inline uint64 mul_add_mod(uint64 a, uint64 b, uint64 c, uint64 mod) {
  uint128 acc = a;
  acc *= b;
  acc += c;
  return mod128_64_small(acc, mod);
}
45+
2746
int64 pow_mod(int64 a, int64 n, int64 m) {
2847
int64 res = 1;
2948
for (a %= m; n; n >>= 1) {
@@ -47,11 +66,21 @@ void exgcd(int64 a, int64 b, int64 &g, int64 &x, int64 &y) {
4766
}
4867
}
4968

50-
// ax = 1 (mod m), gcd(a, m) = 1
51-
int64 mod_inv(int64 a, int64 m) {
52-
int64 d, x, y;
53-
exgcd(a, m, d, x, y);
54-
return d == 1 ? (x % m + m) % m : -1;
69+
// return x, where ax = 1 (mod mod)
70+
int64 mod_inv(int64 a, int64 mod) {
71+
if (gcd(a, mod) != 1) return -1;
72+
int64 b = mod, s = 1, t = 0;
73+
while (b) {
74+
int64 q = a / b;
75+
std::swap(a -= q * b, b);
76+
std::swap(s -= q * t, t);
77+
}
78+
return s < 0 ? s + mod : s;
79+
}
80+
81+
// Chinese remainder theorem for two congruences:
//   x = r1 (mod mod1),  x = r2 (mod mod2),  gcd(mod1, mod2) = 1.
// Returns r1 + k * mod1 with k = (r2 - r1) / mod1 taken modulo mod2, i.e. the
// solution in [0, mod1 * mod2) when r1 < mod1.
//
// Preconditions: the moduli are coprime (checked by assert), mod2 satisfies
// the limits of mul_mod, and mod1 * mod2 fits in 64 bits.
uint64 crt2(uint64 r1, uint64 mod1, uint64 r2, uint64 mod2) {
  const int64 inv = mod_inv(mod1, mod2);
  assert(inv >= 0);  // mod_inv reports -1 when mod1 is not invertible
  // Reduce the difference before handing it to mul_mod, which requires both
  // operands to be < mod2; reducing r1 first also avoids unsigned underflow.
  const uint64 diff = (r2 % mod2 + mod2 - r1 % mod2) % mod2;
  return mul_mod(diff, (uint64)inv, mod2) * mod1 + r1;
}
5685

5786
//ax + by = c,x >= 0, x is minimum
@@ -67,9 +96,9 @@ bool linear_equation(int64 a, int64 b, int64 c, int64 &x, int64 &y) {
6796
}
6897

6998
// Compute Euler's totient function of n (simple version)
70-
int euler_phi(int n) {
71-
int ret = n;
72-
for (int i = 2; i * i <= n; ++i) if (n % i == 0) {
99+
int64 euler_phi(int64 n) {
100+
int64 ret = n;
101+
for (int64 i = 2; i * i <= n; ++i) if (n % i == 0) {
73102
ret = ret / i * (i - 1);
74103
while (n % i == 0) n /= i;
75104
}

mathematics/factorial_p.cc

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#include "basic.hpp"
2+
#include <cstdio>
3+
#include <cmath>
4+
#include <vector>
5+
#include <utility>
6+
7+
using poly_t = std::vector<int64>;
8+
9+
// n! = a * p ^ b, p <= 1e12
10+
// v = sqrt(n)
11+
// f(x) = \prod_{i=1}^{v} (x + i)
12+
// n! = \prod_{i=0}^{v-1} f(i) \prod_{i=v^2+1}^{n} i
13+
std::pair<int64, int64> fact_p(int64 n, int64 p) {
14+
15+
auto shift_evaluation = [&mod = p] (const poly_t &P, int64 a) {
16+
// f(i) --> f(i + a), {i, 0, d}
17+
const int d = P.size() - 1;
18+
};
19+
20+
int64 a = pow_mod(p - 1, n / p, p), b = n / p;
21+
for (int64 x = p; x <= n / p; x *= p) {
22+
b += n / (x * p);
23+
}
24+
if ((n %= p) == 0) return {a, b};
25+
const int64 v = static_cast<int64>(sqrt(n));
26+
for (int64 i = v * v + 1; i <= n; ++i) {
27+
a = mul_mod(a, i, p);
28+
}
29+
return {a, b};
30+
}
31+
32+
// n! / p^{v_p(n!)} mod p^e, assume p^e < 2^63 - 1, pe < 10^6
33+
// (n!)_p = \stirlingfirst{p}{1}^u f_{p,e}(u) \sum_{k=0}^{e-1} (up)^k \stirlingfirst{v+1}{k+1} \bmod p^e
34+
// f_{p,e} = \prod_{i=0}^{x-1}(1 + \sum_{k=1}^{e-1}\frac{\stirlingfirst{p}{k+1}}{\stirlingfirst{p}{1}} (ip)^k)
35+
uint64 fact_pe(uint64 n, uint64 p, uint64 e) {
36+
std::vector<uint64> pows(e + 1, 1);
37+
uint64 pe = 1, min_pe = std::min(p, e);
38+
for (uint64 i = 1; i <= e; ++i) {
39+
pows[i] = (pe *= p);
40+
}
41+
uint64 period = pe / p * 2, deg = e * 2 - 1;
42+
if (p == 2 && e >= 3) period >>= 1;
43+
44+
// first kind stirling number: O(p * min(p, e))
45+
std::vector<uint64> s1(p * min_pe); s1[0] = 1;
46+
for (uint64 i = 1; i < p; ++i) {
47+
int o = i * min_pe;
48+
s1[o] = (uint128)s1[o - min_pe] * i % pe;
49+
for (uint64 j = 1; j < min_pe; ++j) {
50+
s1[o + j] = (s1[o + j - min_pe - 1] + (uint128)s1[o + j - min_pe] * i) % pe;
51+
}
52+
}
53+
54+
// product of {up + 1, ..., up + v} mod p^e
55+
auto fact_range = [&] (uint64 u, uint64 v) {
56+
uint64 coef = (uint128)u % pe * p %pe, prod = 1, ret = 0;
57+
for (uint64 k = 0; k < min_pe; ++k) {
58+
ret = (ret + (uint128)prod * s1[v * min_pe + k]) % pe;
59+
prod = (uint128)prod * coef % pe;
60+
}
61+
return ret;
62+
};
63+
64+
// f_{p,e}(0..2e-2): O(e * min(p, e) + e log(p))
65+
uint64 fac = fact_range(0, p - 1), ifac = mod_inv(fac, pe);
66+
std::vector<uint64> f_pe(deg, 1);
67+
for (uint64 i = 1; i < deg; ++i) {
68+
f_pe[i] = (uint128)f_pe[i - 1] * fact_range(i - 1, p - 1) % pe * ifac % pe;
69+
}
70+
71+
// coprime factorials: O(e + e log(p))
72+
std::vector<uint64> cifac(deg, 1), cfac_vs(deg);
73+
uint64 prod = 1;
74+
for (uint64 i = 1; i < deg; ++i) {
75+
uint64 j = i, v = 0;
76+
for (; j % p == 0; j /= p, ++v);
77+
cfac_vs[i] = cfac_vs[i - 1] + v;
78+
cifac[i - 1] = j;
79+
prod = (uint128)prod * j % pe;
80+
}
81+
cifac[deg - 1] = mod_inv(prod, pe);
82+
for (int i = deg - 2; i >= 0; --i) {
83+
cifac[i] = (uint128)cifac[i + 1] * cifac[i] % pe;
84+
}
85+
86+
// find the value of f_{p, e}(x): O(e log x)
87+
auto evaluate = [&](uint64 x) {
88+
if (x < (uint64)deg) return f_pe[x];
89+
std::vector<uint64> vs(deg), inv(deg);
90+
uint64 v = 0, prod = 1;
91+
for (uint64 i = 0; i < deg; ++i) {
92+
uint64 m = x - i;
93+
for (; m % p == 0; m /= p, ++vs[i]);
94+
v += vs[i];
95+
inv[i] = prod;
96+
prod = (uint128)prod * m % pe;
97+
}
98+
uint64 iprod = mod_inv(prod, pe);
99+
for (int i = deg - 1; i >= 0; --i) {
100+
inv[i] = (uint128)iprod * inv[i] % pe;
101+
iprod = (uint128)iprod * ((x - i) / pows[vs[i]]) % pe;
102+
}
103+
uint64 ret = 0;
104+
for (uint64 i = 0; i < deg; ++i) {
105+
uint64 j = deg - 1 - i, ex = v - vs[i] - cfac_vs[i] - cfac_vs[j];
106+
if (ex >= e) continue;
107+
uint64 add = (uint128)cifac[j] * cifac[i] % pe;
108+
if (j & 1) add = pe - add;
109+
add = (uint128)pows[ex] * prod % pe * inv[i] % pe * add % pe * f_pe[i] % pe;
110+
ret = (ret + add) % pe;
111+
}
112+
return ret;
113+
};
114+
115+
// ((up+v)!)_p mod p^e: O(min(p, e))
116+
auto fact_p = [&](uint64 u, uint64 v) {
117+
return (uint128)fact_range(u, v) * evaluate(u) % pe;
118+
};
119+
120+
uint64 ret = 1, ex = 0;
121+
while (n > 0) {
122+
uint64 q = n / p, v = n % p;
123+
uint64 u = q % period;
124+
ret = (uint128)ret * fact_p(u, v) % pe;
125+
ex += u, n = q;
126+
}
127+
for (ex %= period; ex; ex >>= 1) {
128+
if (ex & 1) ret = (uint128)ret * fac % pe;
129+
fac = (uint128)fac * fac % pe;
130+
}
131+
return ret;
132+
}
133+
134+
int main() {
135+
uint64 p = 5, e = 3, n = 63;
136+
printf("%llu\n", fact_pe(n, p, e));
137+
}

0 commit comments

Comments
 (0)