Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions formal_power_series/factorial_power.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
#include <algorithm>
#include <vector>

// CUT begin
// Convert factorial power -> sampling
// [y[0], y[1], ..., y[N - 1]] -> \sum_i a_i x^\underline{i}
// Complexity: O(N log N)
template <class T> std::vector<T> factorial_to_ys(const std::vector<T> &as) {
const int N = as.size();
std::vector<T> exp(N, 1);
for (int i = 1; i < N; i++) exp[i] = T(i).facinv();
for (int i = 1; i < N; i++) exp[i] = T::facinv(i);
auto ys = nttconv(as, exp);
ys.resize(N);
for (int i = 0; i < N; i++) ys[i] *= T(i).fac();
for (int i = 0; i < N; i++) ys[i] *= T::fac(i);
return ys;
}

Expand All @@ -22,9 +21,9 @@ template <class T> std::vector<T> factorial_to_ys(const std::vector<T> &as) {
// Complexity: O(N log N)
template <class T> std::vector<T> ys_to_factorial(std::vector<T> ys) {
const int N = ys.size();
for (int i = 1; i < N; i++) ys[i] *= T(i).facinv();
for (int i = 1; i < N; i++) ys[i] *= T::facinv(i);
std::vector<T> expinv(N, 1);
for (int i = 1; i < N; i++) expinv[i] = T(i).facinv() * (i % 2 ? -1 : 1);
for (int i = 1; i < N; i++) expinv[i] = T::facinv(i) * (i % 2 ? -1 : 1);
auto as = nttconv(ys, expinv);
as.resize(N);
return as;
Expand All @@ -36,12 +35,12 @@ template <class T> std::vector<T> shift_of_factorial(const std::vector<T> &as, T
const int N = as.size();
std::vector<T> b(N, 1), c(N, 1);
for (int i = 1; i < N; i++) b[i] = b[i - 1] * (shift - i + 1) * T(i).inv();
for (int i = 0; i < N; i++) c[i] = as[i] * T(i).fac();
for (int i = 0; i < N; i++) c[i] = as[i] * T::fac(i);
std::reverse(c.begin(), c.end());
auto ret = nttconv(b, c);
ret.resize(N);
std::reverse(ret.begin(), ret.end());
for (int i = 0; i < N; i++) ret[i] *= T(i).facinv();
for (int i = 0; i < N; i++) ret[i] *= T::facinv(i);
return ret;
}

Expand Down
4 changes: 2 additions & 2 deletions formal_power_series/formal_power_series.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ template <typename T> struct FormalPowerSeries : std::vector<T> {
P shift(T c) const {
const int n = (int)this->size();
P ret = *this;
for (int i = 0; i < n; i++) ret[i] *= T(i).fac();
for (int i = 0; i < n; i++) ret[i] *= T::fac(i);
std::reverse(ret.begin(), ret.end());
P exp_cx(n, 1);
for (int i = 1; i < n; i++) exp_cx[i] = exp_cx[i - 1] * c * T(i).inv();
ret = ret * exp_cx;
ret.resize(n);
std::reverse(ret.begin(), ret.end());
for (int i = 0; i < n; i++) ret[i] *= T(i).facinv();
for (int i = 0; i < n; i++) ret[i] *= T::facinv(i);
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion formal_power_series/lagrange_interpolation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ template <typename MODINT> MODINT interpolate_iota(const std::vector<MODINT> ys,
const int N = ys.size();
if (x_eval.val() < N) return ys[x_eval.val()];
std::vector<MODINT> facinv(N);
facinv[N - 1] = MODINT(N - 1).fac().inv();
facinv[N - 1] = MODINT::facinv(N - 1);
for (int i = N - 1; i > 0; i--) facinv[i - 1] = facinv[i] * i;
std::vector<MODINT> numleft(N);
MODINT numtmp = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ MODINT sum_of_exponential_times_polynomial_limit(MODINT r, std::vector<MODINT> i
MODINT ret = 0;
rp = 1;
for (int i = 0; i <= d; i++) {
ret += bs[d - i] * MODINT(d + 1).nCr(i) * rp;
ret += bs[d - i] * MODINT::binom(d + 1, i) * rp;
rp *= -r;
}
return ret / MODINT(1 - r).pow(d + 1);
Expand Down
2 changes: 1 addition & 1 deletion formal_power_series/test/bernoulli_number.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ int main() {
using mint = ModInt<998244353>;
FormalPowerSeries<mint> x({0, 1});
FormalPowerSeries<mint> b = ((x.exp(N + 2) - 1) >> 1).inv(N + 1);
for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint(i).fac()).val());
for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint::fac(i)).val());
}
2 changes: 1 addition & 1 deletion formal_power_series/test/stirling_number_of_2nd.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ int main() {
cin >> N;
using mint = ModInt<998244353>;
FormalPowerSeries<mint> a(N + 1);
a[N] = mint(N).fac().inv();
a[N] = mint::facinv(N);
for (int i = N - 1; i >= 0; i--) { a[i] = a[i + 1] * (i + 1); }
auto b = a;
for (int i = 0; i <= N; i++) { a[i] *= mint(i).pow(N), b[i] *= (i % 2 ? -1 : 1); }
Expand Down
61 changes: 39 additions & 22 deletions modint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>

template <int md> struct ModInt {
static_assert(md > 1);
using lint = long long;
constexpr static int mod() { return md; }
static int get_primitive_root() {
Expand Down Expand Up @@ -102,39 +103,50 @@ template <int md> struct ModInt {
return this->pow(md - 2);
}
}
constexpr ModInt fac() const {
while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
return facs[this->val_];

constexpr static ModInt fac(int n) {
assert(n >= 0);
if (n >= md) return ModInt(0);
while (n >= int(facs.size())) _precalculation(facs.size() * 2);
return facs[n];
}
constexpr ModInt facinv() const {
while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
return facinvs[this->val_];

constexpr static ModInt facinv(int n) {
assert(n >= 0);
if (n >= md) return ModInt(0);
while (n >= int(facs.size())) _precalculation(facs.size() * 2);
return facinvs[n];
}
constexpr ModInt doublefac() const {
lint k = (this->val_ + 1) / 2;
return (this->val_ & 1) ? ModInt(k * 2).fac() / (ModInt(2).pow(k) * ModInt(k).fac())
: ModInt(k).fac() * ModInt(2).pow(k);

constexpr static ModInt doublefac(int n) {
assert(n >= 0);
if (n >= md) return ModInt(0);
long long k = (n + 1) / 2;
return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k))
: ModInt::fac(k) * ModInt(2).pow(k);
}

constexpr ModInt nCr(int r) const {
if (r < 0 or this->val_ < r) return ModInt(0);
return this->fac() * (*this - r).facinv() * ModInt(r).facinv();
constexpr static ModInt nCr(int n, int r) {
assert(n >= 0);
if (r < 0 or n < r) return ModInt(0);
return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r);
}

constexpr ModInt nPr(int r) const {
if (r < 0 or this->val_ < r) return ModInt(0);
return this->fac() * (*this - r).facinv();
constexpr static ModInt nPr(int n, int r) {
assert(n >= 0);
if (r < 0 or n < r) return ModInt(0);
return ModInt::fac(n) * ModInt::facinv(n - r);
}

static ModInt binom(int n, int r) {
static long long bruteforce_times = 0;

if (r < 0 or n < r) return ModInt(0);
if (n <= bruteforce_times or n < (int)facs.size()) return ModInt(n).nCr(r);
if (n <= bruteforce_times or n < (int)facs.size()) return ModInt::nCr(n, r);

r = std::min(r, n - r);

ModInt ret = ModInt(r).facinv();
ModInt ret = ModInt::facinv(r);
for (int i = 0; i < r; ++i) ret *= n - i;
bruteforce_times += r;

Expand All @@ -148,18 +160,23 @@ template <int md> struct ModInt {
int sum = 0;
for (int k : ks) {
assert(k >= 0);
ret *= ModInt(k).facinv(), sum += k;
ret *= ModInt::facinv(k), sum += k;
}
return ret * ModInt(sum).fac();
return ret * ModInt::fac(sum);
}
template <class... Args> static ModInt multinomial(Args... args) {
int sum = (0 + ... + args);
ModInt result = (1 * ... * ModInt::facinv(args));
return ModInt::fac(sum) * result;
}

// Catalan number, C_n = binom(2n, n) / (n + 1)
// Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n
// C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
// https://oeis.org/A000108
// Complexity: O(n)
static ModInt catalan(int n) {
if (n < 0) return ModInt(0);
return ModInt(n * 2).fac() * ModInt(n + 1).facinv() * ModInt(n).facinv();
return ModInt::fac(n * 2) * ModInt::facinv(n + 1) * ModInt::facinv(n);
}

ModInt sqrt() const {
Expand Down
41 changes: 27 additions & 14 deletions number/modint_runtime.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <cassert>
#include <iostream>
#include <set>
#include <vector>
Expand Down Expand Up @@ -105,26 +106,38 @@ struct ModIntRuntime {
ModIntRuntime pow(lint n) const { return power(n); }
ModIntRuntime inv() const { return this->pow(md - 2); }

ModIntRuntime fac() const {
static ModIntRuntime fac(int n) {
assert(n >= 0);
if (n >= md) return ModIntRuntime(0);
int l0 = facs().size();
if (l0 > this->val_) return facs()[this->val_];

facs().resize(this->val_ + 1);
for (int i = l0; i <= this->val_; i++)
if (l0 > n) return facs()[n];
facs().resize(n + 1);
for (int i = l0; i <= n; i++)
facs()[i] = (i == 0 ? ModIntRuntime(1) : facs()[i - 1] * ModIntRuntime(i));
return facs()[this->val_];
return facs()[n];
}

static ModIntRuntime facinv(int n) { return ModIntRuntime::fac(n).inv(); }

static ModIntRuntime doublefac(int n) {
assert(n >= 0);
if (n >= md) return ModIntRuntime(0);
long long k = (n + 1) / 2;
return (n & 1)
? ModIntRuntime::fac(k * 2) / (ModIntRuntime(2).pow(k) * ModIntRuntime::fac(k))
: ModIntRuntime::fac(k) * ModIntRuntime(2).pow(k);
}

ModIntRuntime doublefac() const {
lint k = (this->val_ + 1) / 2;
return (this->val_ & 1)
? ModIntRuntime(k * 2).fac() / (ModIntRuntime(2).pow(k) * ModIntRuntime(k).fac())
: ModIntRuntime(k).fac() * ModIntRuntime(2).pow(k);
static ModIntRuntime nCr(int n, int r) {
assert(n >= 0);
if (r < 0 or n < r) return ModIntRuntime(0);
return ModIntRuntime::fac(n) / (ModIntRuntime::fac(r) * ModIntRuntime::fac(n - r));
}

ModIntRuntime nCr(int r) const {
if (r < 0 or this->val_ < r) return ModIntRuntime(0);
return this->fac() / ((*this - r).fac() * ModIntRuntime(r).fac());
static ModIntRuntime nPr(int n, int r) {
assert(n >= 0);
if (r < 0 or n < r) return ModIntRuntime(0);
return ModIntRuntime::fac(n) / ModIntRuntime::fac(n - r);
}

ModIntRuntime sqrt() const {
Expand Down