Skip to content

Commit f273d92

Browse files
committed
fixes
1 parent c1d093f commit f273d92

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

cp-algo/math/cvector.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ namespace cp_algo::math::fft {
1515
using vftype [[gnu::vector_size(bytes)]] = ftype;
1616
using vpoint = complex<vftype>;
1717
static constexpr vftype vz = {};
18-
static constexpr vpoint vi = {vz, vz + 1};
18+
vpoint vi(vpoint const& r) {
19+
return {-imag(r), real(r)};
20+
}
1921
vftype abs(vftype a) {
2022
return a < 0 ? -a : a;
2123
}
@@ -132,8 +134,8 @@ namespace cp_algo::math::fft {
132134
auto D = at(j + 3 * i);
133135
at(j) = (A + B + C + D);
134136
at(j + 2 * i) = (A + B - C - D) * v2;
135-
at(j + i) = (A - B - vi * (C - D)) * v1;
136-
at(j + 3 * i) = (A - B + vi * (C - D)) * v3;
137+
at(j + i) = (A - B - vi(C - D)) * v1;
138+
at(j + 3 * i) = (A - B + vi(C - D)) * v3;
137139
}
138140
});
139141
i *= 2;
@@ -171,8 +173,8 @@ namespace cp_algo::math::fft {
171173
auto D = at(j + 3 * i) * v3;
172174
at(j) = (A + C) + (B + D);
173175
at(j + i) = (A + C) - (B + D);
174-
at(j + 2 * i) = (A - C) + vi * (B - D);
175-
at(j + 3 * i) = (A - C) - vi * (B - D);
176+
at(j + 2 * i) = (A - C) + vi(B - D);
177+
at(j + 3 * i) = (A - C) - vi(B - D);
176178
}
177179
});
178180
} else { // radix-2 fallback

cp-algo/math/fft.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ namespace cp_algo::math::fft {
8282
}
8383

8484
void recover_mod(auto &&C, auto &res, size_t k) {
85-
assert(size(res) % flen == 0);
85+
res.assign((k / flen + 1) * flen, base(0));
8686
size_t n = A.size();
8787
auto splitsplit = base(split * split).getr();
8888
base b2x32 = bpow(base(2), 32);
@@ -106,7 +106,7 @@ namespace cp_algo::math::fft {
106106
Au = montgomery_mul(Au, mul, mod, imod);
107107
Au = Au >= base::mod() ? Au - base::mod() : Au;
108108
for(size_t j = 0; j < flen; j++) {
109-
res[i + j].setr(Au[j]);
109+
res[i + j].setr(typename base::UInt(Au[j]));
110110
}
111111
};
112112
set_i(i, Ax, Bx, Cx, cur);
@@ -115,6 +115,7 @@ namespace cp_algo::math::fft {
115115
}
116116
cur = montgomery_mul(cur, step4, mod, imod);
117117
}
118+
res.resize(k);
118119
checkpoint("recover mod");
119120
}
120121

@@ -138,13 +139,13 @@ namespace cp_algo::math::fft {
138139
mul(cvector(B.A), B.B, res, k);
139140
}
140141
std::vector<base, big_alloc<base>> operator *= (dft &B) {
141-
std::vector<base, big_alloc<base>> res(2 * A.size());
142-
mul_inplace(B, res, size(res));
142+
std::vector<base, big_alloc<base>> res;
143+
mul_inplace(B, res, 2 * A.size());
143144
return res;
144145
}
145146
std::vector<base, big_alloc<base>> operator *= (dft const& B) {
146-
std::vector<base, big_alloc<base>> res(2 * A.size());
147-
mul(B, res, size(res));
147+
std::vector<base, big_alloc<base>> res;
148+
mul(B, res, 2 * A.size());
148149
return res;
149150
}
150151
auto operator * (dft const& B) const {
@@ -191,13 +192,11 @@ namespace cp_algo::math::fft {
191192
std::min(k, size(a)) + std::min(k, size(b)) - 1
192193
) / 2);
193194
auto A = dft<base>(a | std::views::take(k), n);
194-
a.assign((k / flen + 1) * flen, 0);
195195
if(&a == &b) {
196196
A.mul(A, a, k);
197197
} else {
198198
A.mul_inplace(dft<base>(b | std::views::take(k), n), a, k);
199199
}
200-
a.resize(k);
201200
}
202201
void mul(auto &a, auto const& b) {
203202
size_t N = size(a) + size(b) - 1;

cp-algo/math/poly/impl/div.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ namespace cp_algo::math::poly::impl {
8787
auto q1f = fft::dft<base>(q1.a, N);
8888
auto qqf = fft::dft<base>(qq.a, N);
8989
size_t M = q0.deg() + (n + 1) / 2;
90-
typename poly::Vector A(M), B(M);
90+
typename poly::Vector A, B;
9191
q0f.mul(qqf, A, M);
9292
q1f.mul_inplace(qqf, B, M);
9393
q.a.resize(n + 1);
@@ -120,7 +120,7 @@ namespace cp_algo::math::poly::impl {
120120
inv_inplace(qq, (n + 1) / 2);
121121
auto qqf = fft::dft<base>(qq.a, N);
122122

123-
typename poly::Vector A((n + 1) / 2), B((n + 1) / 2);
123+
typename poly::Vector A, B;
124124
q0f.mul(qqf, A, (n + 1) / 2);
125125
q1f.mul_inplace(qqf, B, (n + 1) / 2);
126126
p.a.resize(n + 1);

0 commit comments

Comments
 (0)