
Commit 75432a0

classner authored and facebook-github-bot committed
Add OpenCV camera conversion; fix bug for camera unified PyTorch3D interface.
Summary: This commit adds a new camera conversion function, from OpenCV style parameters to Pulsar parameters, to the library. Using this function, it addresses a bug reported here: https://fb.workplace.com/groups/629644647557365/posts/1079637302558095 by using the PyTorch3D->OpenCV->Pulsar chain instead of the original direct conversion function. Both conversions are well tested, and an additional test for the full chain has been added, resulting in a more reliable solution that requires less code.

Reviewed By: patricklabatut

Differential Revision: D29322106

fbshipit-source-id: 13df13c2e48f628f75d9f44f19ff7f1646fb7ebd
1 parent fef5bcd commit 75432a0

8 files changed: +275, -32 lines changed
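To make the conversion chain described in the summary concrete, here is a hedged usage sketch of the new path; the camera values, focal lengths, and image sizes below are placeholders and are not taken from the commit.

    import torch
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.utils import (
        opencv_from_cameras_projection,
        pulsar_from_cameras_projection,
        pulsar_from_opencv_projection,
    )

    # Placeholder batch of two PyTorch3D cameras.
    N = 2
    cameras = PerspectiveCameras(
        R=torch.eye(3).expand(N, 3, 3),
        T=torch.tensor([[0.0, 0.0, 5.0]]).expand(N, 3),
        focal_length=torch.full((N, 2), 2.0),
        principal_point=torch.zeros(N, 2),
    )
    image_size = torch.tensor([[480, 640]] * N)  # (height, width) per camera

    # Step 1: PyTorch3D -> OpenCV.
    R, tvec, camera_matrix = opencv_from_cameras_projection(cameras, image_size)
    # Step 2: OpenCV -> Pulsar; yields an (N, 13) camera parameter vector.
    pulsar_params = pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size)
    # The convenience wrapper added in this commit runs both steps at once.
    assert torch.allclose(
        pulsar_params, pulsar_from_cameras_projection(cameras, image_size)
    )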

pytorch3d/renderer/points/pulsar/unified.py

Lines changed: 29 additions & 23 deletions
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 
-from ....transforms import matrix_to_rotation_6d
+from ....utils import pulsar_from_cameras_projection
 from ...cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
@@ -102,7 +102,7 @@ def __init__(
             height=height,
             max_num_balls=max_num_spheres,
             orthogonal_projection=orthogonal_projection,
-            right_handed_system=True,
+            right_handed_system=False,
             n_channels=n_channels,
             **kwargs,
         )
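The switch to `right_handed_system=False` matches the requirement stated in the new conversion utilities below that the Pulsar renderer use a left-handed coordinate system for the OpenCV-style mapping. A hedged sketch of constructing the low-level Pulsar renderer with this flag; the width/height/sphere-count values are made up, and only the keyword names come from this diff.

    from pytorch3d.renderer.points.pulsar import Renderer

    # Placeholder sizes; the unified interface derives these from its rasterizer settings.
    renderer = Renderer(
        width=640,
        height=480,
        max_num_balls=10_000,
        orthogonal_projection=False,
        right_handed_system=False,  # left-handed system, as the OpenCV->Pulsar mapping requires
        n_channels=3,
    )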
@@ -359,24 +359,28 @@ def _extract_intrinsics(  # noqa: C901
     def _extract_extrinsics(
         self, kwargs, cloud_idx
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Extract the extrinsic information from the kwargs for a specific point cloud.
+
+        Instead of implementing a direct translation from the PyTorch3D to the Pulsar
+        camera model, we chain the two conversions of PyTorch3D->OpenCV and
+        OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
+        tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
+        tested by the Pulsar team).
+        """
         # Shorthand:
         cameras = self.rasterizer.cameras
         R = kwargs.get("R", cameras.R)[cloud_idx]
         T = kwargs.get("T", cameras.T)[cloud_idx]
-        norm_mat = torch.tensor(
-            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
-            dtype=torch.float32,
-            device=R.device,
+        tmp_cams = PerspectiveCameras(
+            R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
         )
-        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
-        norm_mat = torch.tensor(
-            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
-            dtype=torch.float32,
-            device=R.device,
+        size_tensor = torch.tensor(
+            [[self.renderer._renderer.height, self.renderer._renderer.width]]
         )
-        cam_rot = torch.matmul(norm_mat, cam_rot)
-        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
-        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
+        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
+        cam_pos = pulsar_cam[0, :3]
+        cam_rot = pulsar_cam[0, 3:9]
         return cam_pos, cam_rot
 
     def _get_vert_rad(
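As the new docstring describes, the extrinsics are now read out of the 13-element Pulsar camera vector: entries 0-2 hold the camera position and entries 3-8 the 6D rotation. A minimal standalone sketch of that slicing for one camera, using a placeholder (height, width) instead of the renderer's actual size:

    import torch
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.utils import pulsar_from_cameras_projection

    # One camera, wrapped into a batch of one, as _extract_extrinsics does.
    R = torch.eye(3)
    T = torch.tensor([0.0, 0.0, 5.0])
    tmp_cams = PerspectiveCameras(R=R.unsqueeze(0), T=T.unsqueeze(0))
    size_tensor = torch.tensor([[480, 640]])  # placeholder (height, width)

    pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
    cam_pos = pulsar_cam[0, :3]   # camera position, shape (3,)
    cam_rot = pulsar_cam[0, 3:9]  # 6D rotation representation, shape (6,)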
@@ -547,15 +551,17 @@ def forward(self, point_clouds, **kwargs) -> torch.Tensor:
             otherargs["bg_col"] = bg_col
             # Go!
             images.append(
-                self.renderer(
-                    vert_pos=vert_pos,
-                    vert_col=vert_col,
-                    vert_rad=vert_rad,
-                    cam_params=cam_params,
-                    gamma=gamma,
-                    max_depth=zfar,
-                    min_depth=znear,
-                    **otherargs,
+                torch.flipud(
+                    self.renderer(
+                        vert_pos=vert_pos,
+                        vert_col=vert_col,
+                        vert_rad=vert_rad,
+                        cam_params=cam_params,
+                        gamma=gamma,
+                        max_depth=zfar,
+                        min_depth=znear,
+                        **otherargs,
+                    )
                 )
             )
         return torch.stack(images, dim=0)
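Because the OpenCV->Pulsar mapping produces vertically flipped images (see the notes in the new conversion functions below), the unified interface now flips the rendered result back with `torch.flipud`. A minimal sketch of the flip semantics on a placeholder (H, W, C) tensor:

    import torch

    # Placeholder rendered image of shape (H, W, C) as produced by the Pulsar renderer.
    image = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)
    flipped = torch.flipud(image)  # reverses the first (row) dimension only
    # The former top row becomes the bottom row after the flip.
    assert torch.equal(flipped[0], image[-1])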

pytorch3d/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@
 from .camera_conversions import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
+    pulsar_from_opencv_projection,
+    pulsar_from_cameras_projection,
 )
 from .ico_sphere import ico_sphere
 from .torus import torus
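With this export change, the new helpers can be imported next to the existing OpenCV conversions, for example:

    from pytorch3d.utils import (
        cameras_from_opencv_projection,
        opencv_from_cameras_projection,
        pulsar_from_opencv_projection,
        pulsar_from_cameras_projection,
    )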

pytorch3d/utils/camera_conversions.py

Lines changed: 162 additions & 9 deletions
@@ -4,12 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Tuple
 
 import torch
 
 from ..renderer import PerspectiveCameras
-from ..transforms import so3_exp_map, so3_log_map
+from ..transforms import matrix_to_rotation_6d
+
+
+LOGGER = logging.getLogger(__name__)
 
 
 def cameras_from_opencv_projection(
@@ -54,7 +58,6 @@ def cameras_from_opencv_projection(
     Returns:
         cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
     """
-
     focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
     principal_point = camera_matrix[:, :2, 2]
 
@@ -68,7 +71,7 @@ def cameras_from_opencv_projection(
     # For R, T we flip x, y axes (opencv screen space has an opposite
     # orientation of screen axes).
    # We also transpose R (opencv multiplies points from the opposite=left side).
-    R_pytorch3d = R.permute(0, 2, 1)
+    R_pytorch3d = R.clone().permute(0, 2, 1)
     T_pytorch3d = tvec.clone()
     R_pytorch3d[:, :, :2] *= -1
     T_pytorch3d[:, :2] *= -1
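The added `.clone()` matters because `permute` returns a view of the caller's tensor, so the in-place sign flips that follow would otherwise mutate the `R` the user passed in. A small sketch of that aliasing on generic tensors (not the library call itself):

    import torch

    R = torch.eye(3).unsqueeze(0)      # caller's rotation batch, shape (1, 3, 3)
    alias = R.permute(0, 2, 1)         # a view sharing storage with R
    alias[:, :, :2] *= -1              # in-place edit is visible through R as well
    print(R[0, 0, 0])                  # tensor(-1.) -> caller's data was mutated

    R = torch.eye(3).unsqueeze(0)
    safe = R.clone().permute(0, 2, 1)  # copy first, as the fixed code does
    safe[:, :, :2] *= -1
    print(R[0, 0, 0])                  # tensor(1.) -> caller's data untouched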
@@ -103,20 +106,22 @@ def opencv_from_cameras_projection(
         cameras: A batch of `N` cameras in the PyTorch3D convention.
         image_size: A tensor of shape `(N, 2)` containing the sizes of the images
             (height, width) attached to each camera.
+        return_as_rotmat (bool): If set to True, return the full 3x3 rotation
+            matrices. Otherwise, return an axis-angle vector (default).
 
     Returns:
         R: A batch of rotation matrices of shape `(N, 3, 3)`.
         tvec: A batch of translation vectors of shape `(N, 3)`.
         camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
     """
-    R_pytorch3d = cameras.R
-    T_pytorch3d = cameras.T
+    R_pytorch3d = cameras.R.clone()  # pyre-ignore
+    T_pytorch3d = cameras.T.clone()  # pyre-ignore
     focal_pytorch3d = cameras.focal_length
     p0_pytorch3d = cameras.principal_point
-    T_pytorch3d[:, :2] *= -1  # pyre-ignore
-    R_pytorch3d[:, :, :2] *= -1  # pyre-ignore
-    tvec = T_pytorch3d.clone()  # pyre-ignore
-    R = R_pytorch3d.permute(0, 2, 1)  # pyre-ignore
+    T_pytorch3d[:, :2] *= -1
+    R_pytorch3d[:, :, :2] *= -1
+    tvec = T_pytorch3d
+    R = R_pytorch3d.permute(0, 2, 1)
 
     # Retype the image_size correctly and flip to width, height.
     image_size_wh = image_size.to(R).flip(dims=(1,))
@@ -130,3 +135,151 @@ def opencv_from_cameras_projection(
     camera_matrix[:, 0, 0] = focal_length[:, 0]
     camera_matrix[:, 1, 1] = focal_length[:, 1]
     return R, tvec, camera_matrix
+
+
+def pulsar_from_opencv_projection(
+    R: torch.Tensor,
+    tvec: torch.Tensor,
+    camera_matrix: torch.Tensor,
+    image_size: torch.Tensor,
+    znear: float = 0.1,
+) -> torch.Tensor:
+    """
+    Convert OpenCV style camera parameters to Pulsar style camera parameters.
+
+    Note:
+        * Pulsar does NOT support different focal lengths for x and y.
+          For conversion, we use the average of fx and fy.
+        * The Pulsar renderer MUST use a left-handed coordinate system for this
+          mapping to work.
+        * The resulting image will be vertically flipped - which has to be
+          addressed AFTER rendering by the user.
+        * The parameters `R, tvec, camera_matrix` correspond to the outputs
+          of `cv2.decomposeProjectionMatrix`.
+
+    Args:
+        R: A batch of rotation matrices of shape `(N, 3, 3)`.
+        tvec: A batch of translation vectors of shape `(N, 3)`.
+        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+        znear (float): The near clipping value to use for Pulsar.
+
+    Returns:
+        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
+            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
+            c_x, c_y).
+    """
+    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
+    assert len(R.size()) == 3, "This function requires batched inputs!"
+    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
+
+    # Validate parameters.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+    assert torch.all(
+        image_size_wh > 0
+    ), "height and width must be positive but min is: %s" % (
+        str(image_size_wh.min().item())
+    )
+    assert (
+        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
+    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
+        camera_matrix.size(1),
+        camera_matrix.size(2),
+    )
+    assert (
+        R.size(1) == 3 and R.size(2) == 3
+    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
+        R.size(1),
+        R.size(2),
+    )
+    if len(tvec.size()) == 2:
+        tvec = tvec.unsqueeze(2)
+    assert (
+        tvec.size(1) == 3 and tvec.size(2) == 1
+    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
+        tvec.size(1),
+        tvec.size(2),
+    )
+    # Check batch size.
+    batch_size = camera_matrix.size(0)
+    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
+        batch_size,
+        R.size(0),
+    )
+    assert (
+        tvec.size(0) == batch_size
+    ), "Expected tvec to have batch size %d. Has size %d." % (
+        batch_size,
+        tvec.size(0),
+    )
+    # Check image sizes.
+    image_w = image_size_wh[0, 0]
+    image_h = image_size_wh[0, 1]
+    assert torch.all(
+        image_size_wh[:, 0] == image_w
+    ), "All images in a batch must have the same width!"
+    assert torch.all(
+        image_size_wh[:, 1] == image_h
+    ), "All images in a batch must have the same height!"
+    # Focal length.
+    fx = camera_matrix[:, 0, 0].unsqueeze(1)
+    fy = camera_matrix[:, 1, 1].unsqueeze(1)
+    # Check that we introduce less than 1% error by averaging the focal lengths.
+    fx_y = fx / fy
+    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
+        LOGGER.warning(
+            "Pulsar only supports a single focal length. For converting OpenCV "
+            "focal lengths, we average them for x and y directions. "
+            "The focal lengths for x and y you provided differ by more than 1%, "
+            "which means this could introduce a noticeable error."
+        )
+    f = (fx + fy) / 2
+    # Normalize f into normalized device coordinates.
+    focal_length_px = f / image_w
+    # Transfer into focal_length and sensor_width.
+    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
+    focal_length = focal_length[None, :].repeat(batch_size, 1)
+    sensor_width = focal_length / focal_length_px
+    # Principal point.
+    cx = camera_matrix[:, 0, 2].unsqueeze(1)
+    cy = camera_matrix[:, 1, 2].unsqueeze(1)
+    # Transfer principal point offset into centered offset.
+    cx = -(cx - image_w / 2)
+    cy = cy - image_h / 2
+    # Concatenate to final vector.
+    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
+    R_trans = R.permute(0, 2, 1)
+    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
+    cam_rot = matrix_to_rotation_6d(R_trans)
+    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
+    return cam_params
+
+
+def pulsar_from_cameras_projection(
+    cameras: PerspectiveCameras,
+    image_size: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Convert PyTorch3D `PerspectiveCameras` to Pulsar style camera parameters.
+
+    Note:
+        * Pulsar does NOT support different focal lengths for x and y.
+          For conversion, we use the average of fx and fy.
+        * The Pulsar renderer MUST use a left-handed coordinate system for this
+          mapping to work.
+        * The resulting image will be vertically flipped - which has to be
+          addressed AFTER rendering by the user.
+
+    Args:
+        cameras: A batch of `N` cameras in the PyTorch3D convention.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+
+    Returns:
+        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
+            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
+            c_x, c_y).
+    """
+    opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
+    return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
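A hedged sketch of calling `pulsar_from_opencv_projection` directly on OpenCV-style parameters; the intrinsics and extrinsics below are placeholder values:

    import torch
    from pytorch3d.utils import pulsar_from_opencv_projection

    # Placeholder OpenCV-style parameters for a single camera.
    R = torch.eye(3).unsqueeze(0)                      # (N, 3, 3) rotation
    tvec = torch.tensor([[0.0, 0.0, 5.0]])             # (N, 3) translation
    camera_matrix = torch.tensor(
        [[[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]]
    )                                                  # (N, 3, 3) intrinsics
    image_size = torch.tensor([[480, 640]])            # (N, 2) as (height, width)

    cam_params = pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size, znear=0.1)
    # Layout: position (3), 6D rotation (6), focal_length, sensor_width, c_x, c_y.
    print(cam_params.shape)  # torch.Size([1, 13])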