Skip to content

Commit b98f1f7

Browse files
committed
Fixed corner cases
1 parent 35fe2ed commit b98f1f7

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

math/inverse.cpp

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
#include <bits/stdc++.h>
1+
#include <iostream>
2+
#include <vector>
3+
4+
#include <gmpxx.h>
25

36
template<typename T>
47
void equalize_row(std::vector<T> const& A, std::vector<T> &B, T ratio) {
@@ -8,6 +11,13 @@ void equalize_row(std::vector<T> const& A, std::vector<T> &B, T ratio) {
811
}
912
}
1013

14+
template<typename T>
15+
void divide_row(std::vector<T> &A, T ratio) {
16+
for (T& num : A) {
17+
num /= ratio;
18+
}
19+
}
20+
1121
template<typename T>
1222
void print_matrix(std::vector<T> const& A) {
1323
for (auto& row : A) {
@@ -27,48 +37,52 @@ std::vector<std::vector<T>> inverse(std::vector<std::vector<T>> M) {
2737
R[i][i] = T(1);
2838
}
2939

30-
for (int32_t i = 0; i+1 < n; i++) {
31-
T cut_base = M[i][i];
40+
for (int32_t i = 0; i < n; i++) {
41+
int32_t best = i;
3242
for (int32_t j = i+1; j < n; j++) {
33-
T ratio = - M[j][i] / cut_base;
34-
equalize_row(M[i], M[j], ratio);
35-
equalize_row(R[i], R[j], ratio);
43+
if (abs(M[best][i]) < abs(M[j][i])) {
44+
best = j;
45+
}
46+
}
47+
if (best != i) {
48+
std::swap(M[i], M[best]);
49+
std::swap(R[i], R[best]);
50+
}
51+
52+
divide_row<T>(R[i], M[i][i]);
53+
divide_row<T>(M[i], M[i][i]);
54+
55+
for (int32_t j = i+1; j < n; j++) {
56+
T ratio = -M[j][i];
57+
equalize_row<T>(M[i], M[j], ratio);
58+
equalize_row<T>(R[i], R[j], ratio);
3659
}
3760
}
3861

3962
for (int32_t i = n-1; i > 0; i--) {
40-
T cut_base = M[i][i];
4163
for (int32_t j = i-1; j >= 0; j--) {
42-
T ratio = - M[j][i] / cut_base;
43-
equalize_row(M[i], M[j], ratio);
44-
equalize_row(R[i], R[j], ratio);
64+
equalize_row<T>(R[i], R[j], -M[j][i]);
65+
equalize_row<T>(M[i], M[j], -M[j][i]);
4566
}
4667
}
4768

48-
print_matrix(M);
49-
for (int32_t i = 0; i < n; i++) {
50-
if (M[i][i] != T(1)) {
51-
T ratio = T(1) / M[i][i] - T(1);
52-
equalize_row(R[i], R[i], ratio);
53-
equalize_row(M[i], M[i], ratio);
54-
}
55-
}
56-
print_matrix(M);
5769
return R;
5870
}
5971

6072
int main() {
73+
using rational = mpq_class;
74+
6175
int32_t n;
6276
std::cin >> n;
6377

64-
auto M = std::vector<std::vector<double>>(n, std::vector<double>(n));
78+
auto M = std::vector<std::vector<rational>>(n, std::vector<rational>(n));
6579
for (auto& row : M) {
6680
for (auto& x : row) {
6781
std::cin >> x;
6882
}
6983
}
7084

71-
auto ans = inverse(M);
85+
auto A = inverse<rational>(M);
7286

73-
print_matrix(ans);
87+
print_matrix(A);
7488
}

0 commit comments

Comments
 (0)