Skip to content

Commit b466c38

Browse files
davnov134facebook-github-bot
authored andcommitted
Implicit/Volume renderer
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`. Reviewed By: gkioxari Differential Revision: D24418791 fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
1 parent e6a32bf commit b466c38

File tree

8 files changed

+1575
-3
lines changed

8 files changed

+1575
-3
lines changed

pytorch3d/renderer/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
AbsorptionOnlyRaymarcher,
2525
EmissionAbsorptionRaymarcher,
2626
GridRaysampler,
27+
ImplicitRenderer,
2728
MonteCarloRaysampler,
2829
NDCGridRaysampler,
2930
RayBundle,
31+
VolumeRenderer,
32+
VolumeSampler,
3033
ray_bundle_to_ray_points,
3134
ray_bundle_variables_to_ray_points,
3235
)

pytorch3d/renderer/implicit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
44
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
5+
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
56
from .utils import (
67
RayBundle,
78
ray_bundle_to_ray_points,
Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
from typing import Callable, Tuple
3+
4+
import torch
5+
6+
from ...ops.utils import eyes
7+
from ...structures import Volumes
8+
from ...transforms import Transform3d
9+
from ..cameras import CamerasBase
10+
from .raysampling import RayBundle
11+
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
12+
13+
14+
# The implicit renderer class should be initialized with a
15+
# function for raysampling and a function for raymarching.
16+
17+
# During the forward pass:
18+
# 1) The raysampler:
19+
# - samples rays from input cameras
20+
# - transforms the rays to world coordinates
21+
# 2) The volumetric_function (which is a callable argument of the forwad pass)
22+
# evaluates ray_densities and ray_features at the sampled ray-points.
23+
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
24+
# algorithm to render each ray.
25+
26+
27+
class ImplicitRenderer(torch.nn.Module):
28+
"""
29+
A class for rendering a batch of implicit surfaces. The class should
30+
be initialized with a raysampler and raymarcher class which both have
31+
to be a `Callable`.
32+
33+
VOLUMETRIC_FUNCTION
34+
35+
The `forward` function of the renderer accepts as input the rendering cameras as well
36+
as the `volumetric_function` `Callable`, which defines a field of opacity
37+
and feature vectors over the 3D domain of the scene.
38+
39+
A standard `volumetric_function` has the following signature:
40+
```
41+
def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor]
42+
```
43+
With the following arguments:
44+
`ray_bundle`: A RayBundle object containing the following variables:
45+
`rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
46+
the origins of the rendering rays.
47+
`rays_directions`: A tensor of shape `(minibatch, ..., 3)`
48+
containing the direction vectors of rendering rays.
49+
`rays_lengths`: A tensor of shape
50+
`(minibatch, ..., num_points_per_ray)`containing the
51+
lengths at which the ray points are sampled.
52+
Calling `volumetric_function` then returns the following:
53+
`rays_densities`: A tensor of shape
54+
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing
55+
the an opacity vector for each ray point.
56+
`rays_features`: A tensor of shape
57+
`(minibatch, ..., num_points_per_ray, feature_dim)` containing
58+
the an feature vector for each ray point.
59+
60+
Example:
61+
A simple volumetric function of a 0-centered
62+
RGB sphere with a unit diameter is defined as follows:
63+
```
64+
def volumetric_function(
65+
ray_bundle: RayBundle,
66+
) -> Tuple[torch.Tensor, torch.Tensor]:
67+
68+
# first convert the ray origins, directions and lengths
69+
# to 3D ray point locations in world coords
70+
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
71+
72+
# set the densities as an inverse sigmoid of the
73+
# ray point distance from the sphere centroid
74+
rays_densities = torch.sigmoid(
75+
-100.0 * rays_points_world.norm(dim=-1, keepdim=True)
76+
)
77+
78+
# set the ray features to RGB colors proportional
79+
# to the 3D location of the projection of ray points
80+
# on the sphere surface
81+
rays_features = torch.nn.functional.normalize(
82+
rays_points_world, dim=-1
83+
) * 0.5 + 0.5
84+
85+
return rays_densities, rays_features
86+
```
87+
"""
88+
89+
def __init__(self, raysampler: Callable, raymarcher: Callable):
90+
"""
91+
Args:
92+
raysampler: A `Callable` that takes as input scene cameras
93+
(an instance of `CamerasBase`) and returns a `RayBundle` that
94+
describes the rays emitted from the cameras.
95+
raymarcher: A `Callable` that receives the response of the
96+
`volumetric_function` (an input to `self.forward`) evaluated
97+
along the sampled rays, and renders the rays with a
98+
ray-marching algorithm.
99+
"""
100+
super().__init__()
101+
102+
if not callable(raysampler):
103+
raise ValueError('"raysampler" has to be a "Callable" object.')
104+
if not callable(raymarcher):
105+
raise ValueError('"raymarcher" has to be a "Callable" object.')
106+
107+
self.raysampler = raysampler
108+
self.raymarcher = raymarcher
109+
110+
def forward(
111+
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
112+
) -> Tuple[torch.Tensor, RayBundle]:
113+
"""
114+
Render a batch of images using a volumetric function
115+
represented as a callable (e.g. a Pytorch module).
116+
117+
Args:
118+
cameras: A batch of cameras that render the scene. A `self.raysampler`
119+
takes the cameras as input and samples rays that pass through the
120+
domain of the volumentric function.
121+
volumetric_function: A `Callable` that accepts the parametrizations
122+
of the rendering rays and returns the densities and features
123+
at the respective 3D of the rendering rays. Please refer to
124+
the main class documentation for details.
125+
126+
Returns:
127+
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
128+
containing the result of the rendering.
129+
ray_bundle: A `RayBundle` containing the parametrizations of the
130+
sampled rendering rays.
131+
"""
132+
133+
if not callable(volumetric_function):
134+
raise ValueError('"volumetric_function" has to be a "Callable" object.')
135+
136+
# first call the ray sampler that returns the RayBundle parametrizing
137+
# the rendering rays.
138+
ray_bundle = self.raysampler(
139+
cameras=cameras, volumetric_function=volumetric_function, **kwargs
140+
)
141+
# ray_bundle.origins - minibatch x ... x 3
142+
# ray_bundle.directions - minibatch x ... x 3
143+
# ray_bundle.lengths - minibatch x ... x n_pts_per_ray
144+
# ray_bundle.xys - minibatch x ... x 2
145+
146+
# given sampled rays, call the volumetric function that
147+
# evaluates the densities and features at the locations of the
148+
# ray points
149+
rays_densities, rays_features = volumetric_function(
150+
ray_bundle=ray_bundle, cameras=cameras, **kwargs
151+
)
152+
# ray_densities - minibatch x ... x n_pts_per_ray x density_dim
153+
# ray_features - minibatch x ... x n_pts_per_ray x feature_dim
154+
155+
# finally, march along the sampled rays to obtain the renders
156+
images = self.raymarcher(
157+
rays_densities=rays_densities,
158+
rays_features=rays_features,
159+
ray_bundle=ray_bundle,
160+
**kwargs
161+
)
162+
# images - minibatch x ... x (feature_dim + opacity_dim)
163+
164+
return images, ray_bundle
165+
166+
167+
# The volume renderer class should be initialized with a
168+
# function for raysampling and a function for raymarching.
169+
170+
# During the forward pass:
171+
# 1) The raysampler:
172+
# - samples rays from input cameras
173+
# - transforms the rays to world coordinates
174+
# 2) The scene volumes (which are an argument of the forward function)
175+
# are then sampled at the locations of the ray-points to generate
176+
# ray_densities and ray_features.
177+
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
178+
# algorithm to render each ray.
179+
180+
181+
class VolumeRenderer(torch.nn.Module):
182+
"""
183+
A class for rendering a batch of Volumes. The class should
184+
be initialized with a raysampler and a raymarcher class which both have
185+
to be a `Callable`.
186+
"""
187+
188+
def __init__(
189+
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
190+
):
191+
"""
192+
Args:
193+
raysampler: A `Callable` that takes as input scene cameras
194+
(an instance of `CamerasBase`) and returns a `RayBundle` that
195+
describes the rays emitted from the cameras.
196+
raymarcher: A `Callable` that receives the `volumes`
197+
(an instance of `Volumes` input to `self.forward`)
198+
sampled at the ray-points, and renders the rays with a
199+
ray-marching algorithm.
200+
sample_mode: Defines the algorithm used to sample the volumetric
201+
voxel grid. Can be either "bilinear" or "nearest".
202+
"""
203+
super().__init__()
204+
205+
self.renderer = ImplicitRenderer(raysampler, raymarcher)
206+
self._sample_mode = sample_mode
207+
208+
def forward(
209+
self, cameras: CamerasBase, volumes: Volumes, **kwargs
210+
) -> Tuple[torch.Tensor, RayBundle]:
211+
"""
212+
Render a batch of images using raymarching over rays cast through
213+
input `Volumes`.
214+
215+
Args:
216+
cameras: A batch of cameras that render the scene. A `self.raysampler`
217+
takes the cameras as input and samples rays that pass through the
218+
domain of the volumentric function.
219+
volumes: An instance of the `Volumes` class representing a
220+
batch of volumes that are being rendered.
221+
222+
Returns:
223+
images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
224+
containing the result of the rendering.
225+
ray_bundle: A `RayBundle` containing the parametrizations of the
226+
sampled rendering rays.
227+
"""
228+
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
229+
return self.renderer(
230+
cameras=cameras, volumetric_function=volumetric_function, **kwargs
231+
)
232+
233+
234+
class VolumeSampler(torch.nn.Module):
235+
"""
236+
A class that allows to sample a batch of volumes `Volumes`
237+
at 3D points sampled along projection rays.
238+
"""
239+
240+
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
241+
"""
242+
Args:
243+
volumes: An instance of the `Volumes` class representing a
244+
batch if volumes that are being rendered.
245+
sample_mode: Defines the algorithm used to sample the volumetric
246+
voxel grid. Can be either "bilinear" or "nearest".
247+
"""
248+
super().__init__()
249+
if not isinstance(volumes, Volumes):
250+
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
251+
self._volumes = volumes
252+
self._sample_mode = sample_mode
253+
254+
def _get_ray_directions_transform(self):
255+
"""
256+
Compose the ray-directions transform by removing the translation component
257+
from the volume global-to-local coords transform.
258+
"""
259+
world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
260+
directions_transform_matrix = eyes(
261+
4,
262+
N=world2local.shape[0],
263+
device=world2local.device,
264+
dtype=world2local.dtype,
265+
)
266+
directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
267+
directions_transform = Transform3d(matrix=directions_transform_matrix)
268+
return directions_transform
269+
270+
def forward(
271+
self, ray_bundle: RayBundle, **kwargs
272+
) -> Tuple[torch.Tensor, torch.Tensor]:
273+
"""
274+
Given an input ray parametrization, the forward function samples
275+
`self._volumes` at the respective 3D ray-points.
276+
277+
Args:
278+
ray_bundle: A RayBundle object with the following fields:
279+
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
280+
origins of the sampling rays in world coords.
281+
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
282+
containing the direction vectors of sampling rays in world coords.
283+
rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
284+
containing the lengths at which the rays are sampled.
285+
286+
Returns:
287+
rays_densities: A tensor of shape
288+
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
289+
densitity vectors sampled from the volume at the locations of
290+
the ray points.
291+
rays_features: A tensor of shape
292+
`(minibatch, ..., num_points_per_ray, feature_dim)` containing the
293+
feature vectors sampled from the volume at the locations of
294+
the ray points.
295+
"""
296+
297+
# take out the interesting parts of ray_bundle
298+
rays_origins_world = ray_bundle.origins
299+
rays_directions_world = ray_bundle.directions
300+
rays_lengths = ray_bundle.lengths
301+
302+
# validate the inputs
303+
_validate_ray_bundle_variables(
304+
rays_origins_world, rays_directions_world, rays_lengths
305+
)
306+
if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
307+
raise ValueError("Input volumes have to have the same batch size as rays.")
308+
309+
#########################################################
310+
# 1) convert the origins/directions to the local coords #
311+
#########################################################
312+
313+
# origins are mapped with the world_to_local transform of the volumes
314+
rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)
315+
316+
# obtain the Transform3d object that transforms ray directions to local coords
317+
directions_transform = self._get_ray_directions_transform()
318+
319+
# transform the directions to the local coords
320+
rays_directions_local = directions_transform.transform_points(
321+
rays_directions_world.view(rays_lengths.shape[0], -1, 3)
322+
).view(rays_directions_world.shape)
323+
324+
############################
325+
# 2) obtain the ray points #
326+
############################
327+
328+
# this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
329+
rays_points_local = ray_bundle_variables_to_ray_points(
330+
rays_origins_local, rays_directions_local, rays_lengths
331+
)
332+
333+
########################
334+
# 3) sample the volume #
335+
########################
336+
337+
# generate the tensor for sampling
338+
volumes_densities = self._volumes.densities()
339+
dim_density = volumes_densities.shape[1]
340+
volumes_features = self._volumes.features()
341+
# adjust the volumes_features variable in case we have a feature-less volume
342+
if volumes_features is None:
343+
dim_feature = 0
344+
data_to_sample = volumes_densities
345+
else:
346+
dim_feature = volumes_features.shape[1]
347+
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
348+
349+
# reshape to a size which grid_sample likes
350+
rays_points_local_flat = rays_points_local.view(
351+
rays_points_local.shape[0], -1, 1, 1, 3
352+
)
353+
354+
# run the grid sampler
355+
data_sampled = torch.nn.functional.grid_sample(
356+
data_to_sample,
357+
rays_points_local_flat,
358+
align_corners=True,
359+
mode=self._sample_mode,
360+
)
361+
362+
# permute the dimensions & reshape after sampling
363+
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
364+
*rays_points_local.shape[:-1], data_sampled.shape[1]
365+
)
366+
367+
# split back to densities and features
368+
rays_densities, rays_features = data_sampled.split(
369+
[dim_density, dim_feature], dim=-1
370+
)
371+
372+
return rays_densities, rays_features

0 commit comments

Comments
 (0)