Skip to content

Commit 969fc8e

Browse files
committed
Use basic_string/vector + SIMD in linalg
1 parent 5147e7a commit 969fc8e

17 files changed

+150
-119
lines changed

cp-algo/linalg/frobenius.hpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,30 @@
44
#include "matrix.hpp"
55
#include <algorithm>
66
#include <vector>
7+
#include <ranges>
78
namespace cp_algo::linalg {
89
enum frobenius_mode {blocks, full};
910
template<frobenius_mode mode = blocks>
1011
auto frobenius_form(auto const& A) {
1112
using matrix = std::decay_t<decltype(A)>;
13+
using vec_t = matrix::vec_t;
14+
using base = typename matrix::base;
1215
using base = matrix::base;
1316
using polyn = math::poly_t<base>;
1417
assert(A.n() == A.m());
1518
size_t n = A.n();
1619
std::vector<polyn> charps;
17-
std::vector<vec<base>> basis, basis_init;
20+
std::vector<vec_t> basis, basis_init;
1821
while(size(basis) < n) {
1922
size_t start = size(basis);
2023
auto generate_block = [&](auto x) {
2124
while(true) {
22-
vec<base> y = x | vec<base>::ei(n + 1, size(basis));
25+
vec_t y = x | vec_t::ei(n + 1, size(basis));
2326
for(auto &it: basis) {
2427
y.reduce_by(it);
2528
}
2629
y.normalize();
27-
if(vec<base>(y[std::slice(0, n, 1)]) == vec<base>(n)) {
30+
if(std::ranges::count(y | std::views::take(n), base(0)) == int(n)) {
2831
return polyn(typename polyn::Vector(begin(y) + n, end(y)));
2932
} else {
3033
basis_init.push_back(x);
@@ -33,7 +36,7 @@ namespace cp_algo::linalg {
3336
}
3437
}
3538
};
36-
auto full_rec = generate_block(vec<base>::random(n));
39+
auto full_rec = generate_block(vec_t::random(n));
3740
// Extra trimming to make it block-diagonal (expensive)
3841
if constexpr (mode == full) {
3942
if(full_rec.mod_xk(start) != polyn()) {
@@ -58,12 +61,12 @@ namespace cp_algo::linalg {
5861
}
5962
basis[i].normalize();
6063
}
61-
auto T = matrix::from_range(basis_init);
62-
auto Tinv = matrix::from_range(basis);
64+
auto T = matrix(basis_init);
65+
auto Tinv = matrix(basis);
6366
std::ignore = Tinv.sort_classify(n);
6467
for(size_t i = 0; i < n; i++) {
65-
Tinv[i] = vec<base>(
66-
Tinv[i][std::slice(n, n, 1)]
68+
Tinv[i] = vec_t(
69+
Tinv[i] | std::views::drop(n) | std::views::take(n)
6770
) * (base(1) / Tinv[i][i]);
6871
}
6972
return std::tuple{T, Tinv, charps};

cp-algo/linalg/matrix.hpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,26 @@
1010
#include <array>
1111
namespace cp_algo::linalg {
1212
enum gauss_mode {normal, reverse};
13-
template<typename base_t>
14-
struct matrix: valarray_base<matrix<base_t>, vec<base_t>> {
13+
14+
template<typename base_t, class _vec_t = std::conditional_t<
15+
math::modint_type<base_t>,
16+
modint_vec<base_t>,
17+
vec<base_t>>>
18+
struct matrix: std::vector<_vec_t> {
19+
using vec_t = _vec_t;
1520
using base = base_t;
16-
using Base = valarray_base<matrix<base>, vec<base>>;
21+
using Base = std::vector<vec_t>;
1722
using Base::Base;
1823

19-
matrix(size_t n): Base(vec<base>(n), n) {}
20-
matrix(size_t n, size_t m): Base(vec<base>(m), n) {}
24+
matrix(size_t n): Base(n, vec_t(n)) {}
25+
matrix(size_t n, size_t m): Base(n, vec_t(m)) {}
26+
27+
matrix(Base const& t): Base(t) {}
28+
matrix(Base &&t): Base(std::move(t)) {}
29+
30+
static matrix from(auto &&r) {
31+
return std::ranges::to<Base>(r);
32+
}
2133

2234
size_t n() const {return size(*this);}
2335
size_t m() const {return n() ? size(row(0)) : 0;}
@@ -26,6 +38,10 @@ namespace cp_algo::linalg {
2638
auto& row(size_t i) {return (*this)[i];}
2739
auto const& row(size_t i) const {return (*this)[i];}
2840

41+
42+
auto operator-() const {
43+
return from(*this | std::views::transform([](auto x) {return vec_t(-x);}));
44+
}
2945
matrix& operator *=(base t) {for(auto &it: *this) it *= t; return *this;}
3046
matrix operator *(base t) const {return matrix(*this) *= t;}
3147
matrix& operator /=(base t) {return *this *= base(1) / t;}
@@ -34,6 +50,13 @@ namespace cp_algo::linalg {
3450
// Make sure the result is matrix, not Base
3551
matrix& operator *=(matrix const& t) {return *this = *this * t;}
3652

53+
void read_transposed() {
54+
for(size_t j = 0; j < m(); j++) {
55+
for(size_t i = 0; i < n(); i++) {
56+
std::cin >> (*this)[i][j];
57+
}
58+
}
59+
}
3760
void read() {
3861
for(auto &it: *this) {
3962
it.read();
@@ -55,15 +78,15 @@ namespace cp_algo::linalg {
5578
n = 0;
5679
for(auto &it: blocks) {
5780
for(size_t i = 0; i < it.n(); i++) {
58-
res[n + i][std::slice(n, it.n(), 1)] = it[i];
81+
std::ranges::copy(it[i], begin(res[n + i]) + n);
5982
}
6083
n += it.n();
6184
}
6285
return res;
6386
}
6487
static matrix random(size_t n, size_t m) {
6588
matrix res(n, m);
66-
std::ranges::generate(res, std::bind(vec<base>::random, m));
89+
std::ranges::generate(res, std::bind(vec_t::random, m));
6790
return res;
6891
}
6992
static matrix random(size_t n) {
@@ -86,12 +109,9 @@ namespace cp_algo::linalg {
86109
}
87110
return res;
88111
}
89-
matrix submatrix(auto slicex, auto slicey) const {
90-
matrix res = (*this)[slicex];
91-
for(auto &row: res) {
92-
row = vec<base>(row[slicey]);
93-
}
94-
return res;
112+
matrix submatrix(auto viewx, auto viewy) const {
113+
return from(*this | viewx | std::views::transform(
114+
[&](auto const& y) {return vec_t(y | viewy);}));
95115
}
96116

97117
matrix T() const {
@@ -115,8 +135,8 @@ namespace cp_algo::linalg {
115135
return res.normalize();
116136
}
117137

118-
vec<base> apply(vec<base> const& x) const {
119-
return (matrix(x) * *this)[0];
138+
vec_t apply(vec_t const& x) const {
139+
return (matrix(1, x) * *this)[0];
120140
}
121141

122142
matrix pow(uint64_t k) const {
@@ -193,7 +213,7 @@ namespace cp_algo::linalg {
193213
det *= b[i][i];
194214
b[i] *= base(1) / b[i][i];
195215
}
196-
return {det, b.submatrix(std::slice(0, n(), 1), std::slice(n(), n(), 1))};
216+
return {det, b.submatrix(std::views::take(n()), std::views::drop(n()) | std::views::take(n()))};
197217
}
198218

199219
// Can also just run gauss on T() | eye(m)
@@ -218,16 +238,16 @@ namespace cp_algo::linalg {
218238
std::optional<std::array<matrix, 2>> solve(matrix t) const {
219239
matrix sols = (*this | t).kernel();
220240
if(sols.n() < t.m() || sols.submatrix(
221-
std::slice(sols.n() - t.m(), t.m(), 1),
222-
std::slice(m(), t.m(), 1)
241+
std::views::drop(sols.n() - t.m()),
242+
std::views::drop(m())
223243
) != -eye(t.m())) {
224244
return std::nullopt;
225245
} else {
226246
return std::array{
227-
sols.submatrix(std::slice(sols.n() - t.m(), t.m(), 1),
228-
std::slice(0, m(), 1)),
229-
sols.submatrix(std::slice(0, sols.n() - t.m(), 1),
230-
std::slice(0, m(), 1))
247+
sols.submatrix(std::views::drop(sols.n() - t.m()),
248+
std::views::take(m())),
249+
sols.submatrix(std::views::take(sols.n() - t.m()),
250+
std::views::take(m()))
231251
};
232252
}
233253
}

cp-algo/linalg/vector.hpp

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,46 @@
22
#define CP_ALGO_LINALG_VECTOR_HPP
33
#include "../random/rng.hpp"
44
#include "../number_theory/modint.hpp"
5+
#include "../util/big_alloc.hpp"
6+
#include "../util/simd.hpp"
7+
#include "../util/checkpoint.hpp"
58
#include <functional>
69
#include <algorithm>
710
#include <valarray>
811
#include <iostream>
912
#include <iterator>
1013
#include <cassert>
14+
#include <ranges>
1115
namespace cp_algo::linalg {
12-
template<class vec, typename base>
13-
struct valarray_base: std::valarray<base> {
14-
using Base = std::valarray<base>;
16+
template<typename base, class Alloc = big_alloc<base>>
17+
struct vec: std::basic_string<base, std::char_traits<base>, Alloc> {
18+
using Base = std::basic_string<base, std::char_traits<base>, Alloc>;
1519
using Base::Base;
1620

17-
valarray_base(base const& t): Base(t, 1) {}
18-
19-
auto begin() {return std::begin(to_valarray());}
20-
auto begin() const {return std::begin(to_valarray());}
21-
auto end() {return std::end(to_valarray());}
22-
auto end() const {return std::end(to_valarray());}
23-
24-
bool operator == (vec const& t) const {return std::ranges::equal(*this, t);}
25-
bool operator != (vec const& t) const {return !(*this == t);}
26-
27-
vec operator-() const {return Base::operator-();}
28-
29-
static vec from_range(auto const& R) {
30-
vec res(std::ranges::distance(R));
31-
std::ranges::copy(R, res.begin());
32-
return res;
33-
}
34-
Base& to_valarray() {return static_cast<Base&>(*this);}
35-
Base const& to_valarray() const {return static_cast<Base const&>(*this);}
36-
};
37-
38-
template<class vec, typename base>
39-
vec operator+(valarray_base<vec, base> const& a, valarray_base<vec, base> const& b) {
40-
return a.to_valarray() + b.to_valarray();
41-
}
42-
template<class vec, typename base>
43-
vec operator-(valarray_base<vec, base> const& a, valarray_base<vec, base> const& b) {
44-
return a.to_valarray() - b.to_valarray();
45-
}
46-
47-
template<class vec, typename base>
48-
struct vec_base: valarray_base<vec, base> {
49-
using Base = valarray_base<vec, base>;
50-
using Base::Base;
21+
vec(Base const& t): Base(t) {}
22+
vec(Base &&t): Base(std::move(t)) {}
23+
vec(size_t n): Base(n, base()) {}
24+
vec(auto &&r): Base(std::ranges::to<Base>(r)) {}
5125

5226
static vec ei(size_t n, size_t i) {
5327
vec res(n);
5428
res[i] = 1;
5529
return res;
5630
}
5731

32+
auto operator-() const {
33+
return *this | std::views::transform([](auto x) {return -x;});
34+
}
35+
auto operator *(base t) const {
36+
return *this | std::views::transform([t](auto x) {return x * t;});
37+
}
38+
auto operator *=(base t) {
39+
for(auto &it: *this) {
40+
it *= t;
41+
}
42+
return *this;
43+
}
44+
5845
virtual void add_scaled(vec const& b, base scale, size_t i = 0) {
5946
if(scale != base(0)) {
6047
for(; i < size(*this); i++) {
@@ -74,7 +61,9 @@ namespace cp_algo::linalg {
7461
}
7562
}
7663
void print() const {
77-
std::ranges::copy(*this, std::ostream_iterator<base>(std::cout, " "));
64+
for(auto &it: *this) {
65+
std::cout << it << " ";
66+
}
7867
std::cout << "\n";
7968
}
8069
static vec random(size_t n) {
@@ -84,10 +73,10 @@ namespace cp_algo::linalg {
8473
}
8574
// Concatenate vectors
8675
vec operator |(vec const& t) const {
87-
vec res(size(*this) + size(t));
88-
res[std::slice(0, size(*this), 1)] = *this;
89-
res[std::slice(size(*this), size(t), 1)] = t;
90-
return res;
76+
return std::views::join(std::array{
77+
std::views::all(*this),
78+
std::views::all(t)
79+
});
9180
}
9281

9382
// Generally, vec shouldn't be modified
@@ -115,23 +104,32 @@ namespace cp_algo::linalg {
115104
base pivot_inv;
116105
};
117106

118-
template<typename base>
119-
struct vec: vec_base<vec<base>, base> {
120-
using Base = vec_base<vec<base>, base>;
107+
template<math::modint_type base, class Alloc = big_alloc<base>>
108+
struct modint_vec: vec<base, Alloc> {
109+
using Base = vec<base, Alloc>;
121110
using Base::Base;
122-
};
123111

124-
template<math::modint_type base>
125-
struct vec<base>: vec_base<vec<base>, base> {
126-
using Base = vec_base<vec<base>, base>;
127-
using Base::Base;
112+
modint_vec(Base const& t): Base(t) {}
113+
modint_vec(Base &&t): Base(std::move(t)) {}
128114

129-
void add_scaled(vec const& b, base scale, size_t i = 0) override {
115+
void add_scaled(Base const& b, base scale, size_t i = 0) override {
130116
static_assert(base::bits >= 64, "Only wide modint types for linalg");
131-
uint64_t scaler = scale.getr();
132117
if(scale != base(0)) {
133-
for(; i < size(*this); i++) {
134-
(*this)[i].add_unsafe(scaler * b[i].getr_direct());
118+
assert(Base::size() == b.size());
119+
size_t n = size(*this);
120+
u64x4 scaler = u64x4() + scale.getr();
121+
if (is_aligned(this) && is_aligned(&b[0])) // verify we're not in SSO
122+
for(i -= i % 4; i < n - 3; i += 4) {
123+
auto &ai = vector_cast<u64x4>((*this)[i]);
124+
auto bi = vector_cast<u64x4 const>(b[i]);
125+
#ifdef __AVX2__
126+
ai += u64x4(_mm256_mul_epu32(__m256i(scaler), __m256i(bi)));
127+
#else
128+
ai += scaler * bi;
129+
#endif
130+
}
131+
for(; i < n; i++) {
132+
(*this)[i].add_unsafe(b[i].getr_direct() * scale.getr());
135133
}
136134
if(++counter == 4) {
137135
for(auto &it: *this) {
@@ -141,7 +139,7 @@ namespace cp_algo::linalg {
141139
}
142140
}
143141
}
144-
vec const& normalize() override {
142+
Base const& normalize() override {
145143
for(auto &it: *this) {
146144
it.normalize();
147145
}

cp-algo/number_theory/modint.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace cp_algo::math {
2121
static UInt2 modmod() {
2222
return UInt2(mod()) * mod();
2323
}
24-
modint_base(): r(0) {}
24+
modint_base() = default;
2525
modint_base(Int2 rr) {
2626
to_modint().setr(UInt((rr + modmod()) % mod()));
2727
}

cp-algo/util/simd.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ namespace cp_algo {
6060
}
6161

6262
template<std::size_t Align = 32>
63-
constexpr std::size_t aligned_idx(auto const& c, std::size_t i = 0) {
64-
auto const* p = std::data(c) + i;
65-
using value_type = std::remove_pointer_t<decltype(p)>;
66-
constexpr auto mask = Align - 1;
67-
std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(p);
68-
std::size_t bytes_to_next = (-addr) & mask;
69-
return i + bytes_to_next / sizeof(value_type);
63+
[[gnu::always_inline]] inline bool is_aligned(const void* p) noexcept {
64+
return (reinterpret_cast<std::uintptr_t>(p) % Align) == 0;
65+
}
66+
67+
template<class Target>
68+
[[gnu::always_inline]] inline Target& vector_cast(auto &&p) {
69+
return *reinterpret_cast<Target*>(std::assume_aligned<alignof(Target)>(&p));
7070
}
7171
}
7272
#endif // CP_ALGO_UTIL_SIMD_HPP

verify/linalg/adj.test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// @brief Adjugate Matrix
2-
// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/adjugate_matrix
2+
#define PROBLEM "https://judge.yosupo.jp/problem/adjugate_matrix"
33
#pragma GCC optimize("Ofast,unroll-loops")
4-
#include "cp-algo/linalg/matrix.hpp"
54
#include <bits/stdc++.h>
5+
#include "blazingio/blazingio.min.hpp"
6+
#include "cp-algo/linalg/matrix.hpp"
67

78
const int64_t mod = 998244353;
89

0 commit comments

Comments
 (0)