Skip to content

Commit 1af1a36

Browse files
davnov134facebook-github-bot
authored andcommitted
Raymarching
Summary: Implements two basic raymarchers. Reviewed By: gkioxari Differential Revision: D24064250 fbshipit-source-id: 18071bd039995336b7410caa403ea29fafb5c66f
1 parent aa9bcaf commit 1af1a36

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed

pytorch3d/renderer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
look_at_rotation,
2121
look_at_view_transform,
2222
)
23+
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
2324
from .lighting import DirectionalLights, PointLights, diffuse, specular
2425
from .materials import Materials
2526
from .mesh import (
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
4+
5+
6+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
import warnings
3+
from typing import Optional, Tuple, Union
4+
5+
import torch
6+
7+
8+
class EmissionAbsorptionRaymarcher(torch.nn.Module):
9+
"""
10+
Raymarch using the Emission-Absorption (EA) algorithm.
11+
12+
The algorithm independently renders each ray by analyzing density and
13+
feature values sampled at (typically uniformly) spaced 3D locations along
14+
each ray. The density values `rays_densities` are of shape
15+
`(..., n_points_per_ray)`, their values should range between [0, 1], and
16+
represent the opaqueness of each point (the higher the less transparent).
17+
The feature values `rays_features` of shape
18+
`(..., n_points_per_ray, feature_dim)` represent the content of the
19+
point that is supposed to be rendered in case the given point is opaque
20+
(i.e. its density -> 1.0).
21+
22+
EA first utilizes `rays_densities` to compute the absorption function
23+
along each ray as follows:
24+
```
25+
absorption = cumprod(1 - rays_densities, dim=-1)
26+
```
27+
The value of absorption at position `absorption[..., k]` specifies
28+
how much light has reached `k`-th point along a ray since starting
29+
its trajectory at `k=0`-th point.
30+
31+
Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)`
32+
by taking a weighed combination of per-ray features `rays_features` as follows:
33+
```
34+
weights = absorption * rays_densities
35+
features = (rays_features * weights).sum(dim=-2)
36+
```
37+
Where `weights` denote a function that has a strong peak around the location
38+
of the first surface point that a given ray passes through.
39+
40+
Note that for a perfectly bounded volume (with a strictly binary density),
41+
the `weights = cumprod(1 - rays_densities, dim=-1) * rays_densities`
42+
function would yield 0 everywhere. In order to prevent this,
43+
the result of the cumulative product is shifted `self.surface_thickness`
44+
elements along the ray direction.
45+
"""
46+
47+
def __init__(self, surface_thickness: int = 1):
48+
"""
49+
Args:
50+
surface_thickness: Denotes the overlap between the absorption
51+
function and the density function.
52+
"""
53+
super().__init__()
54+
self.surface_thickness = surface_thickness
55+
56+
def forward(
57+
self,
58+
rays_densities: torch.Tensor,
59+
rays_features: torch.Tensor,
60+
eps: float = 1e-10,
61+
**kwargs,
62+
) -> torch.Tensor:
63+
"""
64+
Args:
65+
rays_densities: Per-ray density values represented with a tensor
66+
of shape `(..., n_points_per_ray, 1)` whose values range in [0, 1].
67+
rays_features: Per-ray feature values represented with a tensor
68+
of shape `(..., n_points_per_ray, feature_dim)`.
69+
eps: A lower bound added to `rays_densities` before computing
70+
the absorbtion function (cumprod of `1-rays_densities` along
71+
each ray). This prevents the cumprod to yield exact 0
72+
which would inhibit any gradient-based learning.
73+
74+
Returns:
75+
features_opacities: A tensor of shape `(..., feature_dim+1)`
76+
that concatenates two tensors alonng the last dimension:
77+
1) features: A tensor of per-ray renders
78+
of shape `(..., feature_dim)`.
79+
2) opacities: A tensor of per-ray opacity values
80+
of shape `(..., 1)`. Its values range between [0, 1] and
81+
denote the total amount of light that has been absorbed
82+
for each ray. E.g. a value of 0 corresponds to the ray
83+
completely passing through a volume. Please refer to the
84+
`AbsorptionOnlyRaymarcher` documentation for the
85+
explanation of the algorithm that computes `opacities`.
86+
"""
87+
_check_raymarcher_inputs(
88+
rays_densities,
89+
rays_features,
90+
None,
91+
z_can_be_none=True,
92+
features_can_be_none=False,
93+
density_1d=True,
94+
)
95+
_check_density_bounds(rays_densities)
96+
rays_densities = rays_densities[..., 0]
97+
absorption = _shifted_cumprod(
98+
(1.0 + eps) - rays_densities, shift=self.surface_thickness
99+
)
100+
weights = rays_densities * absorption
101+
features = (weights[..., None] * rays_features).sum(dim=-2)
102+
opacities = 1.0 - torch.prod(1.0 - rays_densities, dim=-1, keepdim=True)
103+
104+
return torch.cat((features, opacities), dim=-1)
105+
106+
107+
class AbsorptionOnlyRaymarcher(torch.nn.Module):
108+
"""
109+
Raymarch using the Absorption-Only (AO) algorithm.
110+
111+
The algorithm independently renders each ray by analyzing density and
112+
feature values sampled at (typically uniformly) spaced 3D locations along
113+
each ray. The density values `rays_densities` are of shape
114+
`(..., n_points_per_ray, 1)`, their values should range between [0, 1], and
115+
represent the opaqueness of each point (the higher the less transparent).
116+
The algorithm only measures the total amount of light absorbed along each ray
117+
and, besides outputting per-ray `opacity` values of shape `(...,)`,
118+
does not produce any feature renderings.
119+
120+
The algorithm simply computes `total_transmission = prod(1 - rays_densities)`
121+
of shape `(..., 1)` which, for each ray, measures the total amount of light
122+
that passed through the volume.
123+
It then returns `opacities = 1 - total_transmission`.
124+
"""
125+
126+
def __init__(self):
127+
super().__init__()
128+
129+
def forward(
130+
self, rays_densities: torch.Tensor, **kwargs
131+
) -> Union[None, torch.Tensor]:
132+
"""
133+
Args:
134+
rays_densities: Per-ray density values represented with a tensor
135+
of shape `(..., n_points_per_ray)` whose values range in [0, 1].
136+
137+
Returns:
138+
opacities: A tensor of per-ray opacity values of shape `(..., 1)`.
139+
Its values range between [0, 1] and denote the total amount
140+
of light that has been absorbed for each ray. E.g. a value
141+
of 0 corresponds to the ray completely passing through a volume.
142+
"""
143+
144+
_check_raymarcher_inputs(
145+
rays_densities,
146+
None,
147+
None,
148+
features_can_be_none=True,
149+
z_can_be_none=True,
150+
density_1d=True,
151+
)
152+
rays_densities = rays_densities[..., 0]
153+
_check_density_bounds(rays_densities)
154+
total_transmission = torch.prod(1 - rays_densities, dim=-1, keepdim=True)
155+
opacities = 1.0 - total_transmission
156+
return opacities
157+
158+
159+
def _shifted_cumprod(x, shift=1):
160+
"""
161+
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
162+
ones and removes `shift` trailing elements to/from the last dimension
163+
of the result.
164+
"""
165+
x_cumprod = torch.cumprod(x, dim=-1)
166+
x_cumprod_shift = torch.cat(
167+
[torch.ones_like(x_cumprod[..., :shift]), x_cumprod[..., :-shift]], dim=-1
168+
)
169+
return x_cumprod_shift
170+
171+
172+
def _check_density_bounds(
173+
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
174+
):
175+
"""
176+
Checks whether the elements of `rays_densities` range within `bounds`.
177+
If not issues a warning.
178+
"""
179+
if ((rays_densities > bounds[1]) | (rays_densities < bounds[0])).any():
180+
warnings.warn(
181+
"One or more elements of rays_densities are outside of valid"
182+
+ f"range {str(bounds)}"
183+
)
184+
185+
186+
def _check_raymarcher_inputs(
187+
rays_densities: torch.Tensor,
188+
rays_features: Optional[torch.Tensor],
189+
rays_z: Optional[torch.Tensor],
190+
features_can_be_none: bool = False,
191+
z_can_be_none: bool = False,
192+
density_1d: bool = True,
193+
):
194+
"""
195+
Checks the validity of the inputs to raymarching algorithms.
196+
"""
197+
if not torch.is_tensor(rays_densities):
198+
raise ValueError("rays_densities has to be an instance of torch.Tensor.")
199+
200+
if not z_can_be_none and not torch.is_tensor(rays_z):
201+
raise ValueError("rays_z has to be an instance of torch.Tensor.")
202+
203+
if not features_can_be_none and not torch.is_tensor(rays_features):
204+
raise ValueError("rays_features has to be an instance of torch.Tensor.")
205+
206+
if rays_densities.ndim < 1:
207+
raise ValueError("rays_densities have to have at least one dimension.")
208+
209+
if density_1d and rays_densities.shape[-1] != 1:
210+
raise ValueError(
211+
"The size of the last dimension of rays_densities has to be one."
212+
)
213+
214+
rays_shape = rays_densities.shape[:-1]
215+
216+
if not z_can_be_none and rays_z.shape != rays_shape:
217+
raise ValueError("rays_z have to be of the same shape as rays_densities.")
218+
219+
if not features_can_be_none and rays_features.shape[:-1] != rays_shape:
220+
raise ValueError(
221+
"The first to previous to last dimensions of rays_features"
222+
" have to be the same as all dimensions of rays_densities."
223+
)

tests/bm_raymarching.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
import itertools
4+
5+
from fvcore.common.benchmark import benchmark
6+
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
7+
from test_raymarching import TestRaymarching
8+
9+
10+
def bm_raymarching() -> None:
11+
case_grid = {
12+
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
13+
"n_rays": [10, 1000, 10000],
14+
"n_pts_per_ray": [10, 1000, 10000],
15+
}
16+
test_cases = itertools.product(*case_grid.values())
17+
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
18+
19+
benchmark(TestRaymarching.raymarcher, "RAYMARCHER", kwargs_list, warmup_iters=1)

0 commit comments

Comments
 (0)