-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathse3.py
223 lines (178 loc) · 7.81 KB
/
se3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch
from .so3 import _so3_exp_map, hat, so3_log_map
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Map a batch of 6-D SE(3) logarithms `log_transform` to 4x4 SE(3)
    matrices via the exponential map.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    An SE(3) matrix uses the following (row-vector) layout:
    ```
    [ R 0 ]
    [ T 1 ] ,
    ```
    with `R` a 3x3 rotation matrix and `T` a 3-D translation vector.
    SE(3) matrices are commonly used to represent rigid motions or camera
    extrinsics.

    Each logarithm is a 6-D vector `[log_translation | log_rotation]`, i.e.
    the concatenation of the two 3-D vectors `log_translation` and
    `log_rotation`. The matrix `transform` is obtained as
    ```
    transform = exp( [ hat(log_rotation) 0 ]
                     [   log_translation 1 ] ) ,
    ```
    where `exp` is the matrix exponential and `hat` is the Hat operator [2].

    For a rotation part with `0 <= ||log_rotation|| < pi` the round trip
    ```
    se3_log_map(se3_exp_map(log_transform)) == log_transform
    ```
    holds up to numerical precision. The rotation angle recovered by the log
    map (via `acos`) lies in `[0, pi]`, so larger angles come back as an
    equivalent but different logarithm.

    The map is singular around `||log_rotation|| = 0`. This is handled by
    clamping controlled with the `eps` argument.

    Args:
        log_transform: Batch of vectors of shape `(minibatch, 6)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid unstable gradients in the singular case.

    Returns:
        Batch of transformation matrices of shape `(minibatch, 4, 4)`.

    Raises:
        ValueError if `log_transform` is of incorrect shape.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """
    if log_transform.dim() != 2 or log_transform.size(1) != 6:
        raise ValueError("Expected input to be of shape (N, 6).")

    batch_size = log_transform.size(0)
    log_translation, log_rotation = log_transform[:, :3], log_transform[:, 3:]

    # Rotation block: SO(3) exponential of the rotation part. The helper also
    # hands back intermediates that the V matrix below reuses.
    R, rotation_angles, log_rotation_hat, log_rotation_hat_square = _so3_exp_map(
        log_rotation, eps=eps
    )

    # Translation block: V @ log_translation.
    V = _se3_V_matrix(
        log_rotation,
        log_rotation_hat,
        log_rotation_hat_square,
        rotation_angles,
        eps=eps,
    )
    T = V.bmm(log_translation.unsqueeze(-1)).squeeze(-1)

    # Fill in column-vector layout, then transpose into the row-vector layout
    # documented above.
    transform = log_transform.new_zeros((batch_size, 4, 4))
    transform[:, :3, :3] = R
    transform[:, :3, 3] = T
    transform[:, 3, 3] = 1.0
    return transform.transpose(1, 2)
def se3_log_map(
    transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
) -> torch.Tensor:
    """
    Convert a batch of 4x4 transformation matrices `transform`
    to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
    ```
    [ R 0 ]
    [ T 1 ] ,
    ```
    where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D
    translation vector. SE(3) matrices are commonly used to represent rigid
    motions or camera extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.

    The conversion from the 4x4 SE(3) matrix `transform` to the
    6D representation `log_transform = [log_translation | log_rotation]`
    is done as follows:
    ```
    log_transform = log(transform)
    log_translation = log_transform[3, :3]
    log_rotation = inv_hat(log_transform[:3, :3])
    ```
    where `log` is the matrix logarithm
    and `inv_hat` is the inverse of the Hat operator [2].

    Note that for any valid 4x4 `transform` matrix, the following identity
    holds (up to numerical precision):
    ```
    se3_exp_map(se3_log_map(transform)) == transform
    ```

    The conversion has a singularity around `(transform=I)`. This is handled
    by clamping controlled with the `eps` and `cos_bound` arguments.

    Args:
        transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid division by zero in the singular case. It is
            used both by `so3_log_map` and for the rotation angle that enters
            the V matrix, matching how `se3_exp_map` uses its `eps`.
        cos_bound: Clamps the cosine of the rotation angle to
            [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs.
            The non-finite outputs can be caused by passing small rotation
            angles to the `acos` function in `so3_rotation_angle` of
            `so3_log_map`.

    Returns:
        Batch of logarithms of input SE(3) matrices
        of shape `(minibatch, 6)`.

    Raises:
        ValueError if `transform` is of incorrect shape.
        ValueError if `R` has an unexpected trace.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """
    if transform.ndim != 3:
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    _, dim1, dim2 = transform.shape
    if dim1 != 4 or dim2 != 4:
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    # In the row-vector layout the last column must be [0, 0, 0, *].
    if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])):
        raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")

    # log_rotation is the SO(3) log of the upper-left 3x3 block, transposed
    # back (se3_exp_map stores R transposed in its row-vector layout).
    R = transform[:, :3, :3].permute(0, 2, 1)
    log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)

    # log_translation solves V @ log_translation = T. Forward the caller's
    # `eps` so the angle clamping inside `_get_se3_V_input` matches the one
    # se3_exp_map applies through `_so3_exp_map(..., eps=eps)`; previously the
    # helper's default was used regardless of `eps`.
    T = transform[:, 3, :3]
    V = _se3_V_matrix(*_get_se3_V_input(log_rotation, eps=eps), eps=eps)
    log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]

    return torch.cat((log_translation, log_rotation), dim=1)
def _se3_V_matrix(
log_rotation: torch.Tensor,
log_rotation_hat: torch.Tensor,
log_rotation_hat_square: torch.Tensor,
rotation_angles: torch.Tensor,
eps: float = 1e-4,
) -> torch.Tensor:
"""
A helper function that computes the "V" matrix from [1], Sec 9.4.2.
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
"""
V = (
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+ log_rotation_hat
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
+ (
log_rotation_hat_square
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
:, None, None
]
)
)
return V
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
    """
    Build the positional arguments of `_se3_V_matrix` from `log_rotation`.

    Args:
        log_rotation: Batch of rotation logarithms, passed to `hat` (which
            presumably expects shape `(minibatch, 3)` — see `.so3`).
        eps: Lower bound on the squared norm of each rotation logarithm
            before taking the square root. This keeps the rotation angles
            away from zero, since `_se3_V_matrix` divides by them.

    Returns:
        Tuple `(log_rotation, log_rotation_hat, log_rotation_hat_square,
        rotation_angles)` in the order `_se3_V_matrix` expects, where
        `rotation_angles = sqrt(max(||log_rotation||**2, eps))`.
    """
    # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
    squared_norms = log_rotation.pow(2).sum(dim=-1)
    angles = squared_norms.clamp(min=eps).sqrt()
    skew = hat(log_rotation)
    skew_square = skew.bmm(skew)
    return log_rotation, skew, skew_square, angles