Skip to content

Commit 202860b

Browse files
committed
fix a bug in fft::trans
1 parent 8f6c34b commit 202860b

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

mathematics/fast-fourier-transform.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ namespace fft {
1515
Complex operator - (const Complex& r) const {
1616
return Complex(x - r.x , y - r.y);
1717
}
18+
Complex operator * (const double k) const {
19+
return Complex(x * k, y * k);
20+
}
21+
Complex operator / (const double k) const {
22+
return Complex(x / k, y / k);
23+
}
1824
Complex operator * (const Complex& r) const {
1925
return Complex(x * r.x - y * r.y , x * r.y + y * r.x);
2026
}
@@ -54,9 +60,14 @@ namespace fft {
5460
}
5561
}
5662
}
63+
if (oper == -1) {
64+
for (int i = 0; i < n; ++i) {
65+
P[i] = P[i] / n;
66+
}
67+
}
5768
}
5869
Complex A[N] , B[N] , C1[N] , C2[N];
59-
std::vector<int> conv(const std::vector<int> &a, const std::vector<int> &b) {
70+
std::vector<int64> conv(const std::vector<int> &a, const std::vector<int> &b) {
6071
int n = a.size(), m = b.size(), s = 1;
6172
while (s <= n + m - 2) s <<= 1;
6273
init(s);
@@ -69,8 +80,11 @@ namespace fft {
6980
for (int i = 0; i < s; ++i) {
7081
A[i] = A[i] * B[i];
7182
}
83+
for (int i = 0; i < s; ++i) {
84+
w[i] = w[i].conj();
85+
}
7286
trans(A, s, -1);
73-
std::vector<int> res(n + m - 1);
87+
std::vector<int64> res(n + m - 1);
7488
for (int i = 0; i < s; ++i) {
7589
res[i] = (int64)(A[i].x + 0.5);
7690
}
@@ -101,10 +115,10 @@ namespace fft {
101115
trans(C1, s, -1);
102116
trans(C2, s, -1);
103117
for (int i = 0 ; i < n + m - 1; ++i) {
104-
int x = (int64)(C1[i].x / s + 0.5) % mod;
105-
int y1 = (int64)(C1[i].y / s + 0.5) % mod;
106-
int y2 = (int64)(C2[i].x / s + 0.5) % mod;
107-
int z = (int64)(C2[i].y / s + 0.5) % mod;
118+
int x = int64(C1[i].x + 0.5) % mod;
119+
int y1 = int64(C1[i].y + 0.5) % mod;
120+
int y2 = int64(C2[i].x + 0.5) % mod;
121+
int z = int64(C2[i].y + 0.5) % mod;
108122
res[i] = ((int64)x * M * M + (int64)(y1 + y2) * M + z) % mod;
109123
}
110124
}

0 commit comments

Comments
 (0)