Skip to content

Commit ece2fed

Browse files
committed
many factorials with SIMD
1 parent 7567e13 commit ece2fed

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

verify/simd/many_facts.test.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// @brief Many Factorials
2+
#define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#include <bits/stdc++.h>
5+
#include "blazingio/blazingio.min.hpp"
6+
#include "cp-algo/util/simd.hpp"
7+
#include "cp-algo/math/common.hpp"
8+
9+
using namespace std;
10+
using namespace cp_algo;
11+
12+
constexpr int mod = 998244353;
13+
constexpr auto mod4 = u64x4() + mod;
14+
constexpr auto imod4 = u64x4() - math::inv2(mod);
15+
16+
void facts_inplace(vector<int> &args) {
17+
constexpr int block = 1 << 16;
18+
static basic_string<size_t> args_per_block[mod / block];
19+
uint64_t limit = 0;
20+
for(auto [i, x]: args | views::enumerate) {
21+
if(x < mod / 2) {
22+
limit = max(limit, uint64_t(x));
23+
args_per_block[x / block].push_back(i);
24+
} else {
25+
limit = max(limit, uint64_t(mod - x - 1));
26+
args_per_block[(mod - x - 1) / block].push_back(i);
27+
}
28+
}
29+
uint64_t b2x32 = (1ULL << 32) % mod;
30+
uint64_t fact = 1;
31+
for(uint64_t b = 0; b <= limit; b += block) {
32+
u64x4 cur = {b, b + block / 4, b + block / 2, b + 3 * block / 4};
33+
static array<u64x4, block / 4> prods;
34+
prods[0] = u64x4{cur[0] + !b, cur[1], cur[2], cur[3]};
35+
cur = cur * b2x32 % mod;
36+
for(int i = 1; i < block / 4; i++) {
37+
cur += b2x32;
38+
cur = cur >= mod ? cur - mod : cur;
39+
prods[i] = montgomery_mul(prods[i - 1], cur, mod4, imod4);
40+
}
41+
for(auto i: args_per_block[b / block]) {
42+
size_t x = args[i];
43+
if(x >= mod / 2) {
44+
x = mod - x - 1;
45+
}
46+
x -= b;
47+
auto pre_blocks = x / (block / 4);
48+
auto in_block = x % (block / 4);
49+
auto ans = fact * prods[in_block][pre_blocks] % mod;
50+
for(size_t z = 0; z < pre_blocks; z++) {
51+
ans = ans * prods.back()[z] % mod;
52+
}
53+
if(args[i] >= mod / 2) {
54+
ans = math::bpow(ans, mod - 2, 1ULL, [](auto a, auto b){return a * b % mod;});
55+
args[i] = int(x % 2 ? ans : mod - ans);
56+
} else {
57+
args[i] = int(ans);
58+
}
59+
}
60+
args_per_block[b / block].clear();
61+
for(int z = 0; z < 4; z++) {
62+
fact = fact * prods.back()[z] % mod;
63+
}
64+
}
65+
}
66+
67+
void solve() {
68+
int n;
69+
cin >> n;
70+
vector<int> args(n);
71+
for(auto &x : args) {cin >> x;}
72+
facts_inplace(args);
73+
for(auto it: args) {cout << it << "\n";}
74+
}
75+
76+
signed main() {
77+
//freopen("input.txt", "r", stdin);
78+
ios::sync_with_stdio(0);
79+
cin.tie(0);
80+
int t = 1;
81+
//cin >> t;
82+
while(t--) {
83+
solve();
84+
}
85+
}

0 commit comments

Comments
 (0)