|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +from typing import Callable, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from ...ops.utils import eyes |
| 7 | +from ...structures import Volumes |
| 8 | +from ...transforms import Transform3d |
| 9 | +from ..cameras import CamerasBase |
| 10 | +from .raysampling import RayBundle |
| 11 | +from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points |
| 12 | + |
| 13 | + |
| 14 | +# The implicit renderer class should be initialized with a |
| 15 | +# function for raysampling and a function for raymarching. |
| 16 | + |
| 17 | +# During the forward pass: |
| 18 | +# 1) The raysampler: |
| 19 | +# - samples rays from input cameras |
| 20 | +# - transforms the rays to world coordinates |
| 21 | +# 2) The volumetric_function (which is a callable argument of the forwad pass) |
| 22 | +# evaluates ray_densities and ray_features at the sampled ray-points. |
| 23 | +# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching |
| 24 | +# algorithm to render each ray. |
| 25 | + |
| 26 | + |
| 27 | +class ImplicitRenderer(torch.nn.Module): |
| 28 | + """ |
| 29 | + A class for rendering a batch of implicit surfaces. The class should |
| 30 | + be initialized with a raysampler and raymarcher class which both have |
| 31 | + to be a `Callable`. |
| 32 | +
|
| 33 | + VOLUMETRIC_FUNCTION |
| 34 | +
|
| 35 | + The `forward` function of the renderer accepts as input the rendering cameras as well |
| 36 | + as the `volumetric_function` `Callable`, which defines a field of opacity |
| 37 | + and feature vectors over the 3D domain of the scene. |
| 38 | +
|
| 39 | + A standard `volumetric_function` has the following signature: |
| 40 | + ``` |
| 41 | + def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor] |
| 42 | + ``` |
| 43 | + With the following arguments: |
| 44 | + `ray_bundle`: A RayBundle object containing the following variables: |
| 45 | + `rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting |
| 46 | + the origins of the rendering rays. |
| 47 | + `rays_directions`: A tensor of shape `(minibatch, ..., 3)` |
| 48 | + containing the direction vectors of rendering rays. |
| 49 | + `rays_lengths`: A tensor of shape |
| 50 | + `(minibatch, ..., num_points_per_ray)`containing the |
| 51 | + lengths at which the ray points are sampled. |
| 52 | + Calling `volumetric_function` then returns the following: |
| 53 | + `rays_densities`: A tensor of shape |
| 54 | + `(minibatch, ..., num_points_per_ray, opacity_dim)` containing |
| 55 | + the an opacity vector for each ray point. |
| 56 | + `rays_features`: A tensor of shape |
| 57 | + `(minibatch, ..., num_points_per_ray, feature_dim)` containing |
| 58 | + the an feature vector for each ray point. |
| 59 | +
|
| 60 | + Example: |
| 61 | + A simple volumetric function of a 0-centered |
| 62 | + RGB sphere with a unit diameter is defined as follows: |
| 63 | + ``` |
| 64 | + def volumetric_function( |
| 65 | + ray_bundle: RayBundle, |
| 66 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 67 | +
|
| 68 | + # first convert the ray origins, directions and lengths |
| 69 | + # to 3D ray point locations in world coords |
| 70 | + rays_points_world = ray_bundle_to_ray_points(ray_bundle) |
| 71 | +
|
| 72 | + # set the densities as an inverse sigmoid of the |
| 73 | + # ray point distance from the sphere centroid |
| 74 | + rays_densities = torch.sigmoid( |
| 75 | + -100.0 * rays_points_world.norm(dim=-1, keepdim=True) |
| 76 | + ) |
| 77 | +
|
| 78 | + # set the ray features to RGB colors proportional |
| 79 | + # to the 3D location of the projection of ray points |
| 80 | + # on the sphere surface |
| 81 | + rays_features = torch.nn.functional.normalize( |
| 82 | + rays_points_world, dim=-1 |
| 83 | + ) * 0.5 + 0.5 |
| 84 | +
|
| 85 | + return rays_densities, rays_features |
| 86 | + ``` |
| 87 | + """ |
| 88 | + |
| 89 | + def __init__(self, raysampler: Callable, raymarcher: Callable): |
| 90 | + """ |
| 91 | + Args: |
| 92 | + raysampler: A `Callable` that takes as input scene cameras |
| 93 | + (an instance of `CamerasBase`) and returns a `RayBundle` that |
| 94 | + describes the rays emitted from the cameras. |
| 95 | + raymarcher: A `Callable` that receives the response of the |
| 96 | + `volumetric_function` (an input to `self.forward`) evaluated |
| 97 | + along the sampled rays, and renders the rays with a |
| 98 | + ray-marching algorithm. |
| 99 | + """ |
| 100 | + super().__init__() |
| 101 | + |
| 102 | + if not callable(raysampler): |
| 103 | + raise ValueError('"raysampler" has to be a "Callable" object.') |
| 104 | + if not callable(raymarcher): |
| 105 | + raise ValueError('"raymarcher" has to be a "Callable" object.') |
| 106 | + |
| 107 | + self.raysampler = raysampler |
| 108 | + self.raymarcher = raymarcher |
| 109 | + |
| 110 | + def forward( |
| 111 | + self, cameras: CamerasBase, volumetric_function: Callable, **kwargs |
| 112 | + ) -> Tuple[torch.Tensor, RayBundle]: |
| 113 | + """ |
| 114 | + Render a batch of images using a volumetric function |
| 115 | + represented as a callable (e.g. a Pytorch module). |
| 116 | +
|
| 117 | + Args: |
| 118 | + cameras: A batch of cameras that render the scene. A `self.raysampler` |
| 119 | + takes the cameras as input and samples rays that pass through the |
| 120 | + domain of the volumentric function. |
| 121 | + volumetric_function: A `Callable` that accepts the parametrizations |
| 122 | + of the rendering rays and returns the densities and features |
| 123 | + at the respective 3D of the rendering rays. Please refer to |
| 124 | + the main class documentation for details. |
| 125 | +
|
| 126 | + Returns: |
| 127 | + images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` |
| 128 | + containing the result of the rendering. |
| 129 | + ray_bundle: A `RayBundle` containing the parametrizations of the |
| 130 | + sampled rendering rays. |
| 131 | + """ |
| 132 | + |
| 133 | + if not callable(volumetric_function): |
| 134 | + raise ValueError('"volumetric_function" has to be a "Callable" object.') |
| 135 | + |
| 136 | + # first call the ray sampler that returns the RayBundle parametrizing |
| 137 | + # the rendering rays. |
| 138 | + ray_bundle = self.raysampler( |
| 139 | + cameras=cameras, volumetric_function=volumetric_function, **kwargs |
| 140 | + ) |
| 141 | + # ray_bundle.origins - minibatch x ... x 3 |
| 142 | + # ray_bundle.directions - minibatch x ... x 3 |
| 143 | + # ray_bundle.lengths - minibatch x ... x n_pts_per_ray |
| 144 | + # ray_bundle.xys - minibatch x ... x 2 |
| 145 | + |
| 146 | + # given sampled rays, call the volumetric function that |
| 147 | + # evaluates the densities and features at the locations of the |
| 148 | + # ray points |
| 149 | + rays_densities, rays_features = volumetric_function( |
| 150 | + ray_bundle=ray_bundle, cameras=cameras, **kwargs |
| 151 | + ) |
| 152 | + # ray_densities - minibatch x ... x n_pts_per_ray x density_dim |
| 153 | + # ray_features - minibatch x ... x n_pts_per_ray x feature_dim |
| 154 | + |
| 155 | + # finally, march along the sampled rays to obtain the renders |
| 156 | + images = self.raymarcher( |
| 157 | + rays_densities=rays_densities, |
| 158 | + rays_features=rays_features, |
| 159 | + ray_bundle=ray_bundle, |
| 160 | + **kwargs |
| 161 | + ) |
| 162 | + # images - minibatch x ... x (feature_dim + opacity_dim) |
| 163 | + |
| 164 | + return images, ray_bundle |
| 165 | + |
| 166 | + |
| 167 | +# The volume renderer class should be initialized with a |
| 168 | +# function for raysampling and a function for raymarching. |
| 169 | + |
| 170 | +# During the forward pass: |
| 171 | +# 1) The raysampler: |
| 172 | +# - samples rays from input cameras |
| 173 | +# - transforms the rays to world coordinates |
| 174 | +# 2) The scene volumes (which are an argument of the forward function) |
| 175 | +# are then sampled at the locations of the ray-points to generate |
| 176 | +# ray_densities and ray_features. |
| 177 | +# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching |
| 178 | +# algorithm to render each ray. |
| 179 | + |
| 180 | + |
| 181 | +class VolumeRenderer(torch.nn.Module): |
| 182 | + """ |
| 183 | + A class for rendering a batch of Volumes. The class should |
| 184 | + be initialized with a raysampler and a raymarcher class which both have |
| 185 | + to be a `Callable`. |
| 186 | + """ |
| 187 | + |
| 188 | + def __init__( |
| 189 | + self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear" |
| 190 | + ): |
| 191 | + """ |
| 192 | + Args: |
| 193 | + raysampler: A `Callable` that takes as input scene cameras |
| 194 | + (an instance of `CamerasBase`) and returns a `RayBundle` that |
| 195 | + describes the rays emitted from the cameras. |
| 196 | + raymarcher: A `Callable` that receives the `volumes` |
| 197 | + (an instance of `Volumes` input to `self.forward`) |
| 198 | + sampled at the ray-points, and renders the rays with a |
| 199 | + ray-marching algorithm. |
| 200 | + sample_mode: Defines the algorithm used to sample the volumetric |
| 201 | + voxel grid. Can be either "bilinear" or "nearest". |
| 202 | + """ |
| 203 | + super().__init__() |
| 204 | + |
| 205 | + self.renderer = ImplicitRenderer(raysampler, raymarcher) |
| 206 | + self._sample_mode = sample_mode |
| 207 | + |
| 208 | + def forward( |
| 209 | + self, cameras: CamerasBase, volumes: Volumes, **kwargs |
| 210 | + ) -> Tuple[torch.Tensor, RayBundle]: |
| 211 | + """ |
| 212 | + Render a batch of images using raymarching over rays cast through |
| 213 | + input `Volumes`. |
| 214 | +
|
| 215 | + Args: |
| 216 | + cameras: A batch of cameras that render the scene. A `self.raysampler` |
| 217 | + takes the cameras as input and samples rays that pass through the |
| 218 | + domain of the volumentric function. |
| 219 | + volumes: An instance of the `Volumes` class representing a |
| 220 | + batch of volumes that are being rendered. |
| 221 | +
|
| 222 | + Returns: |
| 223 | + images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)` |
| 224 | + containing the result of the rendering. |
| 225 | + ray_bundle: A `RayBundle` containing the parametrizations of the |
| 226 | + sampled rendering rays. |
| 227 | + """ |
| 228 | + volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode) |
| 229 | + return self.renderer( |
| 230 | + cameras=cameras, volumetric_function=volumetric_function, **kwargs |
| 231 | + ) |
| 232 | + |
| 233 | + |
| 234 | +class VolumeSampler(torch.nn.Module): |
| 235 | + """ |
| 236 | + A class that allows to sample a batch of volumes `Volumes` |
| 237 | + at 3D points sampled along projection rays. |
| 238 | + """ |
| 239 | + |
| 240 | + def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"): |
| 241 | + """ |
| 242 | + Args: |
| 243 | + volumes: An instance of the `Volumes` class representing a |
| 244 | + batch if volumes that are being rendered. |
| 245 | + sample_mode: Defines the algorithm used to sample the volumetric |
| 246 | + voxel grid. Can be either "bilinear" or "nearest". |
| 247 | + """ |
| 248 | + super().__init__() |
| 249 | + if not isinstance(volumes, Volumes): |
| 250 | + raise ValueError("'volumes' have to be an instance of the 'Volumes' class.") |
| 251 | + self._volumes = volumes |
| 252 | + self._sample_mode = sample_mode |
| 253 | + |
| 254 | + def _get_ray_directions_transform(self): |
| 255 | + """ |
| 256 | + Compose the ray-directions transform by removing the translation component |
| 257 | + from the volume global-to-local coords transform. |
| 258 | + """ |
| 259 | + world2local = self._volumes.get_world_to_local_coords_transform().get_matrix() |
| 260 | + directions_transform_matrix = eyes( |
| 261 | + 4, |
| 262 | + N=world2local.shape[0], |
| 263 | + device=world2local.device, |
| 264 | + dtype=world2local.dtype, |
| 265 | + ) |
| 266 | + directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3] |
| 267 | + directions_transform = Transform3d(matrix=directions_transform_matrix) |
| 268 | + return directions_transform |
| 269 | + |
| 270 | + def forward( |
| 271 | + self, ray_bundle: RayBundle, **kwargs |
| 272 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 273 | + """ |
| 274 | + Given an input ray parametrization, the forward function samples |
| 275 | + `self._volumes` at the respective 3D ray-points. |
| 276 | +
|
| 277 | + Args: |
| 278 | + ray_bundle: A RayBundle object with the following fields: |
| 279 | + rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the |
| 280 | + origins of the sampling rays in world coords. |
| 281 | + rays_directions_world: A tensor of shape `(minibatch, ..., 3)` |
| 282 | + containing the direction vectors of sampling rays in world coords. |
| 283 | + rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` |
| 284 | + containing the lengths at which the rays are sampled. |
| 285 | +
|
| 286 | + Returns: |
| 287 | + rays_densities: A tensor of shape |
| 288 | + `(minibatch, ..., num_points_per_ray, opacity_dim)` containing the |
| 289 | + densitity vectors sampled from the volume at the locations of |
| 290 | + the ray points. |
| 291 | + rays_features: A tensor of shape |
| 292 | + `(minibatch, ..., num_points_per_ray, feature_dim)` containing the |
| 293 | + feature vectors sampled from the volume at the locations of |
| 294 | + the ray points. |
| 295 | + """ |
| 296 | + |
| 297 | + # take out the interesting parts of ray_bundle |
| 298 | + rays_origins_world = ray_bundle.origins |
| 299 | + rays_directions_world = ray_bundle.directions |
| 300 | + rays_lengths = ray_bundle.lengths |
| 301 | + |
| 302 | + # validate the inputs |
| 303 | + _validate_ray_bundle_variables( |
| 304 | + rays_origins_world, rays_directions_world, rays_lengths |
| 305 | + ) |
| 306 | + if self._volumes.densities().shape[0] != rays_origins_world.shape[0]: |
| 307 | + raise ValueError("Input volumes have to have the same batch size as rays.") |
| 308 | + |
| 309 | + ######################################################### |
| 310 | + # 1) convert the origins/directions to the local coords # |
| 311 | + ######################################################### |
| 312 | + |
| 313 | + # origins are mapped with the world_to_local transform of the volumes |
| 314 | + rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world) |
| 315 | + |
| 316 | + # obtain the Transform3d object that transforms ray directions to local coords |
| 317 | + directions_transform = self._get_ray_directions_transform() |
| 318 | + |
| 319 | + # transform the directions to the local coords |
| 320 | + rays_directions_local = directions_transform.transform_points( |
| 321 | + rays_directions_world.view(rays_lengths.shape[0], -1, 3) |
| 322 | + ).view(rays_directions_world.shape) |
| 323 | + |
| 324 | + ############################ |
| 325 | + # 2) obtain the ray points # |
| 326 | + ############################ |
| 327 | + |
| 328 | + # this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3) |
| 329 | + rays_points_local = ray_bundle_variables_to_ray_points( |
| 330 | + rays_origins_local, rays_directions_local, rays_lengths |
| 331 | + ) |
| 332 | + |
| 333 | + ######################## |
| 334 | + # 3) sample the volume # |
| 335 | + ######################## |
| 336 | + |
| 337 | + # generate the tensor for sampling |
| 338 | + volumes_densities = self._volumes.densities() |
| 339 | + dim_density = volumes_densities.shape[1] |
| 340 | + volumes_features = self._volumes.features() |
| 341 | + # adjust the volumes_features variable in case we have a feature-less volume |
| 342 | + if volumes_features is None: |
| 343 | + dim_feature = 0 |
| 344 | + data_to_sample = volumes_densities |
| 345 | + else: |
| 346 | + dim_feature = volumes_features.shape[1] |
| 347 | + data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1) |
| 348 | + |
| 349 | + # reshape to a size which grid_sample likes |
| 350 | + rays_points_local_flat = rays_points_local.view( |
| 351 | + rays_points_local.shape[0], -1, 1, 1, 3 |
| 352 | + ) |
| 353 | + |
| 354 | + # run the grid sampler |
| 355 | + data_sampled = torch.nn.functional.grid_sample( |
| 356 | + data_to_sample, |
| 357 | + rays_points_local_flat, |
| 358 | + align_corners=True, |
| 359 | + mode=self._sample_mode, |
| 360 | + ) |
| 361 | + |
| 362 | + # permute the dimensions & reshape after sampling |
| 363 | + data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view( |
| 364 | + *rays_points_local.shape[:-1], data_sampled.shape[1] |
| 365 | + ) |
| 366 | + |
| 367 | + # split back to densities and features |
| 368 | + rays_densities, rays_features = data_sampled.split( |
| 369 | + [dim_density, dim_feature], dim=-1 |
| 370 | + ) |
| 371 | + |
| 372 | + return rays_densities, rays_features |
0 commit comments