Skip to content

Commit 0acd4cd

Browse files
committed
add fast Number Theoretic Transform
1 parent dcd0f52 commit 0acd4cd

File tree

2 files changed

+282
-35
lines changed

2 files changed

+282
-35
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
#include "basic.hpp"
2+
#include <ctime>
3+
#include <cstdio>
4+
#include <cassert>
5+
#include <vector>
6+
#include <algorithm>
7+
8+
// Short integer aliases used as the word / double-word / signed-word template
// arguments of ntt::Mod further down in this file.
using int64 = long long;
using uint32 = unsigned int;
using uint64 = unsigned long long;
using uint128 = __uint128_t;  // compiler extension type (GCC/Clang)
12+
13+
namespace ntt {
14+
// If mod is not close to 2^(word_bits-1), it is faster to switch to the commented-out alternative lines inside Mod.
15+
// Residue modulo `mod`, stored in Montgomery form: the member `x` holds
// value * 2^word_bits (mod `mod`), so a product needs one `reduce` call
// instead of a division.
//
// Template parameters:
//   word  - unsigned storage type
//   dword - unsigned type exactly twice as wide as `word` (holds products)
//   sword - signed counterpart of `word`, used only for sign tests
//   mod   - odd modulus below 2^(word_bits - 1)
//   root  - base from which omega() derives a root of unity
//
// NOTE(review): inverse() computes x^(mod - 2), which is only an inverse when
// `mod` is prime - confirm for any new modulus.
template <class word, class dword, class sword, word mod, word root>
class Mod {
public:
  // Inverse of odd `n` modulo 2^word_bits by Newton iteration
  // x <- x * (2 - x * n). Every step doubles the number of correct low bits,
  // so the default six steps starting from x = 1 cover up to 64 bits.
  static constexpr word mul_inv(word n, int e = 6, word x = 1) {
    return e == 0 ? x : mul_inv(n, e - 1, x * (2 - x * n));
  }

  static constexpr word inv = mul_inv(mod);       // mod^-1 modulo 2^word_bits
  static constexpr word r2 = -dword(mod) % mod;   // 2^(2 * word_bits) modulo mod
  static constexpr int word_bits = sizeof(word) * 8;
  static constexpr int level = __builtin_ctzll(mod - 1);  // 2^level divides mod - 1

  static_assert(sizeof(dword) == 2 * sizeof(word),
                "dword must be exactly twice as wide as word");
  static_assert((mod & 1) == 1,
                "Montgomery reduction needs an odd modulus");
  static_assert((mod) < (word(1) << (word_bits - 1)),
                "modulus must be below 2^(word_bits - 1)");

  static word modulus() {
    return mod;
  }

  // Converts a plain integer into Montgomery form.
  static word init(const word& w) {
    return reduce(dword(w) * r2);
  }

  // Montgomery reduction: returns w / 2^word_bits modulo `mod`, in [0, mod).
  static word reduce(const dword& w) {
    word y = word(w >> word_bits) - word((dword(word(w) * inv) * mod) >> word_bits);
    // The difference lies in (-mod, mod); a set sign bit marks the negative half.
    return sword(y) < 0 ? y + mod : y;
    // Alternative without the final correction:
    //return word(w >> word_bits) + mod - word((dword(word(w) * inv) * mod) >> word_bits);
  }

  // root^((mod - 1) / 2^level): a 2^level-th root of unity (primitive when
  // `root` is a primitive root of `mod`).
  static Mod omega() {
    return Mod(root).pow((mod - 1) >> level);
  }

  Mod() = default;
  // Implicit on purpose: callers rely on int -> Mod conversions
  // (for example std::fill(..., 0)).
  Mod(const word& n): x(init(n)) {}

  Mod& operator += (const Mod& rhs) {
    // Alternative without the range correction:
    //this->x += rhs.x;
    if ((x += rhs.x) >= mod) x -= mod;
    return *this;
  }
  Mod& operator -= (const Mod& rhs) {
    // Alternative without the range correction:
    //this->x += mod * 3 - rhs.x;
    if (sword(x -= rhs.x) < 0) x += mod;
    return *this;
  }
  Mod& operator *= (const Mod& rhs) {
    this->x = reduce(dword(this->x) * rhs.x);
    return *this;
  }
  Mod operator + (const Mod& rhs) const {
    return Mod(*this) += rhs;
  }
  Mod operator - (const Mod& rhs) const {
    return Mod(*this) -= rhs;
  }
  Mod operator * (const Mod& rhs) const {
    return Mod(*this) *= rhs;
  }

  // Leaves Montgomery form and returns the plain residue. The trailing
  // `% mod` is a no-op with the corrected reduce() above; it matters only if
  // the uncorrected alternative is enabled.
  word get() const {
    return reduce(this->x) % mod;
  }

  Mod inverse() const {
    return pow(mod - 2);
  }

  // Binary exponentiation: this^e.
  Mod pow(word e) const {
    Mod ret(1);
    for (Mod a = *this; e; e >>= 1) {
      if (e & 1) ret *= a;
      a *= a;
    }
    return ret;
  }

  word x;  // Montgomery representation of the residue
};
83+
84+
// Radix-2 butterfly: replaces (x, y) with (x + y, x - y).
// Both results are computed from the original values of x and y.
template <class T>
inline void sum_diff(T& x, T& y) {
  const auto a = x, b = y;
  x = a + b;
  y = a - b;
}
89+
90+
// Matters Computational. 26.2.3.1
91+
// In-place radix-4 decimation-in-time NTT of A[0..n).
//
//   A     - data, overwritten with the transform in natural order
//   n     - length; must be a power of two with log2(n) <= mod_t::level
//   sgn   - sgn < 0 uses the inverse roots (inverse transform, not scaled by 1/n)
//   roots - roots[i] must have multiplicative order 2^(mod_t::level - i)
//   rev   - bit-reversal permutation of [0, n)
template <class mod_t>
void ntt_dit4(mod_t A[], int n, int sgn, mod_t roots[], int *rev) {
  static_assert(mod_t::level >= 2, "ntt_dit4 needs a 4th root of unity");
  assert(n > 0 && (n & (n - 1)) == 0);
  const int logn = __builtin_ctz(n);
  assert(logn <= mod_t::level);  // otherwise roots[level - e] indexes below 0

  // Reorder the input into bit-reversed order.
  for (int i = 0; i < n; ++i) {
    if (i < rev[i]) std::swap(A[i], A[rev[i]]);
  }

  // Odd log2(n): one plain radix-2 pass so the remaining passes are radix-4.
  if (logn & 1) {
    for (int i = 0; i < n; i += 2) {
      const auto a = A[i], b = A[i + 1];
      A[i] = a + b;
      A[i + 1] = a - b;
    }
  }

  const auto one = mod_t(1);
  auto im = roots[mod_t::level - 2];  // element of order 4
  if (sgn < 0) im = im.inverse();

  for (int e = 2 + (logn & 1); e <= logn; e += 2) {
    const int m = 1 << e, m4 = m >> 2;
    auto dw = roots[mod_t::level - e];  // element of order m
    if (sgn < 0) dw = dw.inverse();
    // Work on chunks of at least one butterfly group and at most 2^15 bytes.
    // NOTE(review): presumably sized for cache locality - confirm.
    const int block_size = std::min(n, std::max(m, (1 << 15) / int(sizeof(A[0]))));
    for (int k = 0; k < n; k += block_size) {
      auto w = one, w2 = one, w3 = one;
      for (int j = 0; j < m4; ++j) {
        for (int i = k + j; i < k + block_size; i += m) {
          // Slots 1 and 2 hold the sub-transforms in bit-reversed order,
          // hence the a2 / a1 naming. The `* one` keeps a0 on the same
          // footing as the other operands.
          auto a0 = A[i + m4 * 0] * one, a2 = A[i + m4 * 1] * w2;
          auto a1 = A[i + m4 * 2] * w, a3 = A[i + m4 * 3] * w3;
          auto t02 = a0 + a2, t13 = a1 + a3;
          A[i + m4 * 0] = t02 + t13;
          A[i + m4 * 2] = t02 - t13;
          t02 = a0 - a2, t13 = (a1 - a3) * im;
          A[i + m4 * 1] = t02 + t13;
          A[i + m4 * 3] = t02 - t13;
        }
        w *= dw; w2 = w * w; w3 = w2 * w;
      }
    }
  }
}
125+
126+
// Matters Computational. 26.2.3.2
127+
// In-place radix-4 decimation-in-frequency NTT of A[0..n).
//
//   A     - data in natural order, overwritten with the transform in natural
//           order (the bit-reversal permutation is applied at the end)
//   n     - length; must be a power of two with log2(n) <= mod_t::level
//   sgn   - sgn < 0 uses the inverse roots (inverse transform, not scaled by 1/n)
//   roots - roots[i] must have multiplicative order 2^(mod_t::level - i)
//   rev   - bit-reversal permutation of [0, n)
//
// The butterfly reads its four inputs in natural order (slots 0, 1, 2, 3) and
// stores the outputs X0, X2, X1, X3 in slots 0, 1, 2, 3, so that the final
// bit reversal puts them in natural order. Using the DIT routine's swapped
// slot order here does not yield a DFT.
template <class mod_t>
void ntt_dif4(mod_t A[], int n, int sgn, mod_t roots[], int *rev) {
  static_assert(mod_t::level >= 2, "ntt_dif4 needs a 4th root of unity");
  assert(n > 0 && (n & (n - 1)) == 0);
  const int logn = __builtin_ctz(n);
  assert(logn <= mod_t::level);  // otherwise roots[level - e] indexes below 0

  const auto one = mod_t(1);
  auto im = roots[mod_t::level - 2];  // element of order 4
  if (sgn < 0) im = im.inverse();

  for (int e = logn; e >= 2; e -= 2) {
    const int m = 1 << e, m4 = m >> 2;
    auto dw = roots[mod_t::level - e];  // element of order m
    if (sgn < 0) dw = dw.inverse();
    // Work on chunks of at least one butterfly group and at most 2^15 bytes.
    // NOTE(review): presumably sized for cache locality - confirm.
    const int block_size = std::min(n, std::max(m, (1 << 15) / int(sizeof(A[0]))));
    for (int k = 0; k < n; k += block_size) {
      auto w = one, w2 = one, w3 = one;
      for (int j = 0; j < m4; ++j) {
        for (int i = k + j; i < k + block_size; i += m) {
          const auto a0 = A[i + m4 * 0], a1 = A[i + m4 * 1];
          const auto a2 = A[i + m4 * 2], a3 = A[i + m4 * 3];
          auto t02 = a0 + a2, t13 = a1 + a3;
          A[i + m4 * 0] = (t02 + t13) * one;
          A[i + m4 * 1] = (t02 - t13) * w2;
          t02 = a0 - a2, t13 = (a1 - a3) * im;
          A[i + m4 * 2] = (t02 + t13) * w;
          A[i + m4 * 3] = (t02 - t13) * w3;
        }
        w *= dw; w2 = w * w; w3 = w2 * w;
      }
    }
  }

  // Odd log2(n): finish with one plain radix-2 pass on adjacent pairs.
  if (logn & 1) {
    for (int i = 0; i < n; i += 2) {
      const auto a = A[i], b = A[i + 1];
      A[i] = a + b;
      A[i + 1] = a - b;
    }
  }

  // Undo the bit-reversed output order.
  for (int i = 0; i < n; ++i) {
    if (i < rev[i]) std::swap(A[i], A[rev[i]]);
  }
}
159+
160+
template <class mod_t>
161+
void convolute(mod_t A[], int n, mod_t B[], int m, bool cyclic = false) {
162+
int s = (cyclic ? std::max(n, m) : n + m - 1), size = 1, logn = 0;
163+
while (size < s) size <<= 1, ++logn;
164+
mod_t roots[mod_t::level] = {mod_t::omega()};
165+
for (int i = 1; i < mod_t::level; ++i) {
166+
roots[i] = roots[i - 1] * roots[i - 1];
167+
}
168+
std::vector<int> rev(size);
169+
for (int i = 0; i < size; ++i) {
170+
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (logn - 1));
171+
}
172+
std::fill(A + n, A + size, 0);
173+
ntt_dit4(A, size, 1, roots, &rev[0]);
174+
if (A == B && n == m) {
175+
for (int i = 0; i < size; ++i) A[i] *= A[i];
176+
} else {
177+
std::fill(B + m, B + size, 0);
178+
ntt_dit4(B, size, 1, roots, &rev[0]);
179+
for (int i = 0; i < size; ++i) A[i] *= B[i];
180+
}
181+
ntt_dit4(A, size, -1, roots, &rev[0]);
182+
mod_t inv = mod_t(size).inverse();
183+
if (!cyclic) size = s;
184+
for (int i = 0; i < size; ++i) A[i] *= inv;
185+
}
186+
}
187+
188+
// transform uses DIF and itransform uses DIT, so no bit-reversal permutation is needed
189+
namespace ntt_fast {
190+
// Forward in-place radix-4 NTT of A[0..n) that skips the bit-reversal pass:
// the output is left in a permuted order, meant to be consumed by the
// matching itransform().
//
//   n      - length; must be a power of two with log2(n) <= mod_t::level
//   roots  - roots[i] must have multiplicative order 2^(mod_t::level - i)
//   iroots - iroots[i] must be the inverse of roots[i]
template <typename mod_t>
void transform(mod_t* A, int n, const mod_t* roots, const mod_t* iroots) {
  static_assert(mod_t::level >= 3, "transform needs an 8th root of unity");
  assert(n > 0 && (n & (n - 1)) == 0);
  const int logn = __builtin_ctz(n), nh = n >> 1, lv = mod_t::level;
  assert(logn <= lv);  // keeps the dw[] lookups below inside the table

  const auto one = mod_t(1), im = roots[lv - 2];  // im has order 4

  // dw[k] is the factor that moves the per-block twiddle from block b to
  // block b + 1 when b ends in exactly k one bits (see the update of w2).
  mod_t dw[lv - 1];
  dw[0] = roots[lv - 3];
  for (int i = 1; i < lv - 2; ++i) {
    dw[i] = dw[i - 1] * iroots[lv - 1 - i] * roots[lv - 3 - i];
  }
  dw[lv - 2] = dw[lv - 3] * iroots[1];

  // Odd log2(n): one twiddle-free radix-2 pass across the two halves.
  if (logn & 1) {
    for (int i = 0; i < nh; ++i) {
      const auto a = A[i], b = A[i + nh];
      A[i] = a + b;
      A[i + nh] = a - b;
    }
  }

  for (int e = logn & ~1; e >= 2; e -= 2) {
    const int m = 1 << e, m4 = m >> 2;
    auto w2 = one;
    for (int i = 0; i < n; i += m) {
      // One twiddle triple per block, shared by every butterfly inside it.
      const auto w1 = w2 * w2, w3 = w1 * w2;
      for (int j = i; j < i + m4; ++j) {
        const auto a0 = A[j + m4 * 0] * one, a1 = A[j + m4 * 1] * w2;
        const auto a2 = A[j + m4 * 2] * w1, a3 = A[j + m4 * 3] * w3;
        const auto t02p = a0 + a2, t13p = a1 + a3;
        const auto t02m = a0 - a2, t13m = (a1 - a3) * im;
        A[j + m4 * 0] = t02p + t13p;
        A[j + m4 * 1] = t02p - t13p;
        A[j + m4 * 2] = t02m + t13m;
        A[j + m4 * 3] = t02m - t13m;
      }
      // Index = number of trailing one bits of the current block number.
      w2 *= dw[__builtin_ctz(~(i >> e))];
    }
  }
}
219+
220+
// Inverse counterpart of transform(): takes data in the permuted order that
// transform() produces and writes the result in natural order. The result is
// NOT scaled by 1 / n; the caller has to apply that factor.
//
//   n      - length; must be a power of two with log2(n) <= mod_t::level
//   roots  - roots[i] must have multiplicative order 2^(mod_t::level - i)
//   iroots - iroots[i] must be the inverse of roots[i]
template <typename mod_t>
void itransform(mod_t* A, int n, const mod_t* roots, const mod_t* iroots) {
  static_assert(mod_t::level >= 3, "itransform needs an 8th root of unity");
  assert(n > 0 && (n & (n - 1)) == 0);
  const int logn = __builtin_ctz(n), nh = n >> 1, lv = mod_t::level;
  assert(logn <= lv);  // keeps the dw[] lookups below inside the table

  const auto one = mod_t(1), im = iroots[lv - 2];  // im has order 4

  // Same table as in transform(), built with roots and iroots exchanged.
  mod_t dw[lv - 1];
  dw[0] = iroots[lv - 3];
  for (int i = 1; i < lv - 2; ++i) {
    dw[i] = dw[i - 1] * roots[lv - 1 - i] * iroots[lv - 3 - i];
  }
  dw[lv - 2] = dw[lv - 3] * roots[1];

  for (int e = 2; e <= logn; e += 2) {
    const int m = 1 << e, m4 = m >> 2;
    auto w2 = one;
    for (int i = 0; i < n; i += m) {
      // One twiddle triple per block, applied after the butterfly.
      const auto w1 = w2 * w2, w3 = w1 * w2;
      for (int j = i; j < i + m4; ++j) {
        const auto a0 = A[j + m4 * 0], a1 = A[j + m4 * 1];
        const auto a2 = A[j + m4 * 2], a3 = A[j + m4 * 3];
        const auto t01p = a0 + a1, t23p = a2 + a3;
        const auto t01m = a0 - a1, t23m = (a2 - a3) * im;
        A[j + m4 * 0] = (t01p + t23p) * one;
        A[j + m4 * 2] = (t01p - t23p) * w1;
        A[j + m4 * 1] = (t01m + t23m) * w2;
        A[j + m4 * 3] = (t01m - t23m) * w3;
      }
      // Index = number of trailing one bits of the current block number.
      w2 *= dw[__builtin_ctz(~(i >> e))];
    }
  }

  // Odd log2(n): finish with one twiddle-free radix-2 pass across the halves.
  if (logn & 1) {
    for (int i = 0; i < nh; ++i) {
      const auto a = A[i], b = A[i + nh];
      A[i] = a + b;
      A[i + nh] = a - b;
    }
  }
}
249+
250+
// Convolution of A[0..n) and B[0..m) modulo mod_t's modulus using
// transform() / itransform(); the result is written to A.
//
//   cyclic == false: the transform length covers n + m - 1 outputs.
//   cyclic == true : the transform length is max(n, m) rounded up to a power
//                    of two.
//
// Requirements visible from the code:
//   * n >= 1 and m >= 1 (the size computation feeds 2 * s - 1 to
//     __builtin_clz, which needs a positive argument);
//   * A and B must each have room for `size` elements; both are zero-padded
//     up to `size`;
//   * B is overwritten with its transform, unless A == B && n == m, in which
//     case A is squared and B is not touched separately;
//   * passing A == B with n != m is not supported (the second padding pass
//     would clobber the already transformed data).
//
// The 1 / size factor is folded into the pointwise product, so all `size`
// entries of A come out scaled.
template <typename mod_t>
void convolute(mod_t* A, int n, mod_t* B, int m, bool cyclic = false) {
  assert(n > 0 && m > 0);
  assert(A != B || n == m);
  const int s = cyclic ? std::max(n, m) : n + m - 1;
  // Smallest power of two that is >= s.
  const int size = 1 << (31 - __builtin_clz(2 * s - 1));
  assert(__builtin_ctz(size) <= mod_t::level);  // root of order `size` must exist

  // roots[i] has order 2^(level - i); iroots[i] is its inverse.
  mod_t roots[mod_t::level], iroots[mod_t::level];
  roots[0] = mod_t::omega();
  for (int i = 1; i < mod_t::level; ++i) {
    roots[i] = roots[i - 1] * roots[i - 1];
  }
  iroots[0] = roots[0].inverse();
  for (int i = 1; i < mod_t::level; ++i) {
    iroots[i] = iroots[i - 1] * iroots[i - 1];
  }

  std::fill(A + n, A + size, 0);
  transform(A, size, roots, iroots);
  const auto inv = mod_t(size).inverse();
  if (A == B && n == m) {
    for (int i = 0; i < size; ++i) A[i] *= A[i] * inv;
  } else {
    std::fill(B + m, B + size, 0);
    transform(B, size, roots, iroots);
    for (int i = 0; i < size; ++i) A[i] *= B[i] * inv;
  }
  itransform(A, size, roots, iroots);
}
273+
}
274+
275+
// 4405523190172876801, 19
276+
// 4481719345977753601, 11
277+
// 4601552919265804289, 3
278+
using mod_1 = ntt::Mod<uint64, uint128, int64, 709143768229478401, 31>;
279+
using mod_2 = ntt::Mod<uint64, uint128, int64, 711416664922521601, 19>;
280+
using mod_3 = ntt::Mod<uint64, uint128, int64, 1945555039024054273, 5>;
281+
using mod_4 = ntt::Mod<uint32, uint64, int, 2013265921, 31>;
282+
using mod_5 = ntt::Mod<uint32, uint64, int, 2113929217, 5>;

mathematics/number-theoretic-transform.cc

Lines changed: 0 additions & 35 deletions
This file was deleted.

0 commit comments

Comments
 (0)