|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 3 | +""" |
| 4 | +This example demonstrates camera parameter optimization with the plain |
| 5 | +pulsar interface. For this, a reference image has been pre-generated |
| 6 | +(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`). |
| 7 | +The same scene parameterization is loaded and the camera parameters |
| 8 | +distorted. Gradient-based optimization is used to converge towards the |
| 9 | +original camera parameters. |
| 10 | +""" |
| 11 | +from os import path |
| 12 | + |
| 13 | +import cv2 |
| 14 | +import imageio |
| 15 | +import numpy as np |
| 16 | +import torch |
| 17 | +from pytorch3d.renderer.points.pulsar import Renderer |
| 18 | +from torch import nn, optim |
| 19 | + |
| 20 | + |
# Global example configuration: number of spheres in the scene, the
# render resolution in pixels, and the device everything is placed on.
n_points = 20
width, height = 1_000, 1_000
device = torch.device("cuda")
| 25 | + |
| 26 | + |
class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene model is parameterized with sphere locations (vert_pos),
    channel content (vert_col), radii (vert_rad), camera position (cam_pos),
    camera rotation (cam_rot) and sensor focal length and width (cam_sensor).

    Only the camera parameters (cam_pos, cam_rot, cam_sensor) are created
    with `requires_grad=True`; the sphere parameters are frozen.

    The forward method of the model renders this scene description. Any
    of these parameters could instead be passed as inputs to the forward
    method and come from a different model.
    """

    def __init__(self):
        super().__init__()
        self.gamma = 0.1
        # Points.
        # The global seed is fixed so that the random scene is reproducible.
        # The order of the three `torch.rand` draws below (positions, colors,
        # radii) determines the scene and must not be changed.
        torch.manual_seed(1)
        # x and y end up in [-5, 5), z in [25, 35).
        vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False
            ),
        )
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.rand(n_points, dtype=torch.float32), requires_grad=False
            ),
        )
        # Camera parameters: these are the values being optimized. They start
        # from a distorted state (see the module docstring).
        self.register_parameter(
            "cam_pos",
            nn.Parameter(
                torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), requires_grad=True
            ),
        )
        self.register_parameter(
            "cam_rot",
            nn.Parameter(
                torch.tensor(
                    [
                        # We're using the 6D rot. representation for better gradients.
                        0.9995,
                        0.0300445,
                        -0.0098482,
                        -0.0299445,
                        0.9995,
                        0.0101482,
                    ],
                    dtype=torch.float32,
                ),
                requires_grad=True,
            ),
        )
        self.register_parameter(
            "cam_sensor",
            nn.Parameter(
                torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True
            ),
        )
        self.renderer = Renderer(width, height, n_points)

    def forward(self):
        """
        Render the scene with the current parameters.

        The camera is passed to the renderer as one flat vector of
        position (3 values), rotation (6 values) and sensor (2 values).

        Returns:
            The output of the pulsar renderer for this scene.
        """
        cam_params = torch.cat([self.cam_pos, self.cam_rot, self.cam_sensor])
        # Call the module instead of `.forward` so that `nn.Module.__call__`
        # runs any registered hooks.
        return self.renderer(
            self.vert_pos,
            self.vert_col,
            self.vert_rad,
            cam_params,
            self.gamma,
            # NOTE(review): presumably the maximum render depth; confirm
            # against the pulsar Renderer API.
            45.0,
        )
| 102 | + |
| 103 | + |
# Load reference.
# The image is scaled from 8-bit values to floats in [0, 1] and moved to
# `device`. The path is relative to the current working directory.
ref = (
    torch.from_numpy(
        imageio.imread(
            "../../tests/pulsar/reference/examples_TestRenderer_test_cam.png"
        )
    ).to(torch.float32)
    / 255.0
).to(device)
# Set up model.
model = SceneModel().to(device)
# Optimizer.
# Only the camera parameters are optimized, each with its own learning rate.
optimizer = optim.SGD(
    [
        {"params": [model.cam_pos], "lr": 1e-4},  # 1e-3
        {"params": [model.cam_rot], "lr": 5e-6},
        {"params": [model.cam_sensor], "lr": 1e-4},
    ]
)

print("Writing video to `%s`." % (path.abspath("cam.gif")))
# Use the writer as a context manager so the GIF is finalized and the file
# handle released even if an iteration raises (e.g. a CUDA or OpenCV error).
with imageio.get_writer("cam.gif", format="gif", fps=25) as writer:
    # Optimize.
    for i in range(300):
        optimizer.zero_grad()
        result = model()
        # Visualize.
        result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
        # The last axis is reversed for display; presumably RGB -> BGR, the
        # channel order OpenCV expects.
        cv2.imshow("opt", result_im[:, :, ::-1])
        writer.append_data(result_im)
        # 50/50 blend of the current rendering and the reference.
        # The reversed view is copied into contiguous memory before drawing.
        overlay_img = np.ascontiguousarray(
            ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
                :, :, ::-1
            ]
        )
        overlay_img = cv2.putText(
            overlay_img,
            "Step %d" % (i),
            (10, 40),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
            False,
        )
        cv2.imshow("overlay", overlay_img)
        cv2.waitKey(1)
        # Update.
        # Sum of squared differences between the rendering and the reference.
        loss = ((result - ref) ** 2).sum()
        print("loss {}: {}".format(i, loss.item()))
        loss.backward()
        optimizer.step()
0 commit comments