|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +import warnings |
| 3 | +from typing import Optional, Tuple, Union |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | + |
| 8 | +class EmissionAbsorptionRaymarcher(torch.nn.Module): |
| 9 | + """ |
| 10 | + Raymarch using the Emission-Absorption (EA) algorithm. |
| 11 | +
|
| 12 | + The algorithm independently renders each ray by analyzing density and |
| 13 | + feature values sampled at (typically uniformly) spaced 3D locations along |
| 14 | + each ray. The density values `rays_densities` are of shape |
| 15 | + `(..., n_points_per_ray)`, their values should range between [0, 1], and |
| 16 | + represent the opaqueness of each point (the higher the less transparent). |
| 17 | + The feature values `rays_features` of shape |
| 18 | + `(..., n_points_per_ray, feature_dim)` represent the content of the |
| 19 | + point that is supposed to be rendered in case the given point is opaque |
| 20 | + (i.e. its density -> 1.0). |
| 21 | +
|
| 22 | + EA first utilizes `rays_densities` to compute the absorption function |
| 23 | + along each ray as follows: |
| 24 | + ``` |
| 25 | + absorption = cumprod(1 - rays_densities, dim=-1) |
| 26 | + ``` |
| 27 | + The value of absorption at position `absorption[..., k]` specifies |
| 28 | + how much light has reached `k`-th point along a ray since starting |
| 29 | + its trajectory at `k=0`-th point. |
| 30 | +
|
| 31 | + Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)` |
| 32 | + by taking a weighed combination of per-ray features `rays_features` as follows: |
| 33 | + ``` |
| 34 | + weights = absorption * rays_densities |
| 35 | + features = (rays_features * weights).sum(dim=-2) |
| 36 | + ``` |
| 37 | + Where `weights` denote a function that has a strong peak around the location |
| 38 | + of the first surface point that a given ray passes through. |
| 39 | +
|
| 40 | + Note that for a perfectly bounded volume (with a strictly binary density), |
| 41 | + the `weights = cumprod(1 - rays_densities, dim=-1) * rays_densities` |
| 42 | + function would yield 0 everywhere. In order to prevent this, |
| 43 | + the result of the cumulative product is shifted `self.surface_thickness` |
| 44 | + elements along the ray direction. |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__(self, surface_thickness: int = 1): |
| 48 | + """ |
| 49 | + Args: |
| 50 | + surface_thickness: Denotes the overlap between the absorption |
| 51 | + function and the density function. |
| 52 | + """ |
| 53 | + super().__init__() |
| 54 | + self.surface_thickness = surface_thickness |
| 55 | + |
| 56 | + def forward( |
| 57 | + self, |
| 58 | + rays_densities: torch.Tensor, |
| 59 | + rays_features: torch.Tensor, |
| 60 | + eps: float = 1e-10, |
| 61 | + **kwargs, |
| 62 | + ) -> torch.Tensor: |
| 63 | + """ |
| 64 | + Args: |
| 65 | + rays_densities: Per-ray density values represented with a tensor |
| 66 | + of shape `(..., n_points_per_ray, 1)` whose values range in [0, 1]. |
| 67 | + rays_features: Per-ray feature values represented with a tensor |
| 68 | + of shape `(..., n_points_per_ray, feature_dim)`. |
| 69 | + eps: A lower bound added to `rays_densities` before computing |
| 70 | + the absorbtion function (cumprod of `1-rays_densities` along |
| 71 | + each ray). This prevents the cumprod to yield exact 0 |
| 72 | + which would inhibit any gradient-based learning. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + features_opacities: A tensor of shape `(..., feature_dim+1)` |
| 76 | + that concatenates two tensors alonng the last dimension: |
| 77 | + 1) features: A tensor of per-ray renders |
| 78 | + of shape `(..., feature_dim)`. |
| 79 | + 2) opacities: A tensor of per-ray opacity values |
| 80 | + of shape `(..., 1)`. Its values range between [0, 1] and |
| 81 | + denote the total amount of light that has been absorbed |
| 82 | + for each ray. E.g. a value of 0 corresponds to the ray |
| 83 | + completely passing through a volume. Please refer to the |
| 84 | + `AbsorptionOnlyRaymarcher` documentation for the |
| 85 | + explanation of the algorithm that computes `opacities`. |
| 86 | + """ |
| 87 | + _check_raymarcher_inputs( |
| 88 | + rays_densities, |
| 89 | + rays_features, |
| 90 | + None, |
| 91 | + z_can_be_none=True, |
| 92 | + features_can_be_none=False, |
| 93 | + density_1d=True, |
| 94 | + ) |
| 95 | + _check_density_bounds(rays_densities) |
| 96 | + rays_densities = rays_densities[..., 0] |
| 97 | + absorption = _shifted_cumprod( |
| 98 | + (1.0 + eps) - rays_densities, shift=self.surface_thickness |
| 99 | + ) |
| 100 | + weights = rays_densities * absorption |
| 101 | + features = (weights[..., None] * rays_features).sum(dim=-2) |
| 102 | + opacities = 1.0 - torch.prod(1.0 - rays_densities, dim=-1, keepdim=True) |
| 103 | + |
| 104 | + return torch.cat((features, opacities), dim=-1) |
| 105 | + |
| 106 | + |
| 107 | +class AbsorptionOnlyRaymarcher(torch.nn.Module): |
| 108 | + """ |
| 109 | + Raymarch using the Absorption-Only (AO) algorithm. |
| 110 | +
|
| 111 | + The algorithm independently renders each ray by analyzing density and |
| 112 | + feature values sampled at (typically uniformly) spaced 3D locations along |
| 113 | + each ray. The density values `rays_densities` are of shape |
| 114 | + `(..., n_points_per_ray, 1)`, their values should range between [0, 1], and |
| 115 | + represent the opaqueness of each point (the higher the less transparent). |
| 116 | + The algorithm only measures the total amount of light absorbed along each ray |
| 117 | + and, besides outputting per-ray `opacity` values of shape `(...,)`, |
| 118 | + does not produce any feature renderings. |
| 119 | +
|
| 120 | + The algorithm simply computes `total_transmission = prod(1 - rays_densities)` |
| 121 | + of shape `(..., 1)` which, for each ray, measures the total amount of light |
| 122 | + that passed through the volume. |
| 123 | + It then returns `opacities = 1 - total_transmission`. |
| 124 | + """ |
| 125 | + |
| 126 | + def __init__(self): |
| 127 | + super().__init__() |
| 128 | + |
| 129 | + def forward( |
| 130 | + self, rays_densities: torch.Tensor, **kwargs |
| 131 | + ) -> Union[None, torch.Tensor]: |
| 132 | + """ |
| 133 | + Args: |
| 134 | + rays_densities: Per-ray density values represented with a tensor |
| 135 | + of shape `(..., n_points_per_ray)` whose values range in [0, 1]. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + opacities: A tensor of per-ray opacity values of shape `(..., 1)`. |
| 139 | + Its values range between [0, 1] and denote the total amount |
| 140 | + of light that has been absorbed for each ray. E.g. a value |
| 141 | + of 0 corresponds to the ray completely passing through a volume. |
| 142 | + """ |
| 143 | + |
| 144 | + _check_raymarcher_inputs( |
| 145 | + rays_densities, |
| 146 | + None, |
| 147 | + None, |
| 148 | + features_can_be_none=True, |
| 149 | + z_can_be_none=True, |
| 150 | + density_1d=True, |
| 151 | + ) |
| 152 | + rays_densities = rays_densities[..., 0] |
| 153 | + _check_density_bounds(rays_densities) |
| 154 | + total_transmission = torch.prod(1 - rays_densities, dim=-1, keepdim=True) |
| 155 | + opacities = 1.0 - total_transmission |
| 156 | + return opacities |
| 157 | + |
| 158 | + |
| 159 | +def _shifted_cumprod(x, shift=1): |
| 160 | + """ |
| 161 | + Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of |
| 162 | + ones and removes `shift` trailing elements to/from the last dimension |
| 163 | + of the result. |
| 164 | + """ |
| 165 | + x_cumprod = torch.cumprod(x, dim=-1) |
| 166 | + x_cumprod_shift = torch.cat( |
| 167 | + [torch.ones_like(x_cumprod[..., :shift]), x_cumprod[..., :-shift]], dim=-1 |
| 168 | + ) |
| 169 | + return x_cumprod_shift |
| 170 | + |
| 171 | + |
| 172 | +def _check_density_bounds( |
| 173 | + rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0) |
| 174 | +): |
| 175 | + """ |
| 176 | + Checks whether the elements of `rays_densities` range within `bounds`. |
| 177 | + If not issues a warning. |
| 178 | + """ |
| 179 | + if ((rays_densities > bounds[1]) | (rays_densities < bounds[0])).any(): |
| 180 | + warnings.warn( |
| 181 | + "One or more elements of rays_densities are outside of valid" |
| 182 | + + f"range {str(bounds)}" |
| 183 | + ) |
| 184 | + |
| 185 | + |
| 186 | +def _check_raymarcher_inputs( |
| 187 | + rays_densities: torch.Tensor, |
| 188 | + rays_features: Optional[torch.Tensor], |
| 189 | + rays_z: Optional[torch.Tensor], |
| 190 | + features_can_be_none: bool = False, |
| 191 | + z_can_be_none: bool = False, |
| 192 | + density_1d: bool = True, |
| 193 | +): |
| 194 | + """ |
| 195 | + Checks the validity of the inputs to raymarching algorithms. |
| 196 | + """ |
| 197 | + if not torch.is_tensor(rays_densities): |
| 198 | + raise ValueError("rays_densities has to be an instance of torch.Tensor.") |
| 199 | + |
| 200 | + if not z_can_be_none and not torch.is_tensor(rays_z): |
| 201 | + raise ValueError("rays_z has to be an instance of torch.Tensor.") |
| 202 | + |
| 203 | + if not features_can_be_none and not torch.is_tensor(rays_features): |
| 204 | + raise ValueError("rays_features has to be an instance of torch.Tensor.") |
| 205 | + |
| 206 | + if rays_densities.ndim < 1: |
| 207 | + raise ValueError("rays_densities have to have at least one dimension.") |
| 208 | + |
| 209 | + if density_1d and rays_densities.shape[-1] != 1: |
| 210 | + raise ValueError( |
| 211 | + "The size of the last dimension of rays_densities has to be one." |
| 212 | + ) |
| 213 | + |
| 214 | + rays_shape = rays_densities.shape[:-1] |
| 215 | + |
| 216 | + if not z_can_be_none and rays_z.shape != rays_shape: |
| 217 | + raise ValueError("rays_z have to be of the same shape as rays_densities.") |
| 218 | + |
| 219 | + if not features_can_be_none and rays_features.shape[:-1] != rays_shape: |
| 220 | + raise ValueError( |
| 221 | + "The first to previous to last dimensions of rays_features" |
| 222 | + " have to be the same as all dimensions of rays_densities." |
| 223 | + ) |
0 commit comments