
Commit 956d3a0

nikhilaravi authored and facebook-github-bot committed
Support for moving the renderer to a new device
Summary: Support for moving all the tensors of the renderer to another device by calling `renderer.to(new_device)`.

Currently the `MeshRenderer`, `MeshRasterizer` and `SoftPhongShader` (and the other shaders) are all of type `nn.Module`, which already supports moving the tensors of submodules (defined as class attributes) to a different device. However, the class attributes of the rasterizer and shader (e.g. cameras, lights, materials) are of type `TensorProperties`, not `nn.Module`, so we need to explicitly define a `to` method to move these tensors to the device. Note that the `TensorProperties` class already has a `to` method, so we only need to call `cameras.to(device)` and don't need to worry about the internal tensors. The other option is, of course, making these other classes (cameras, lights, etc.) also of type `nn.Module`.

Reviewed By: gkioxari

Differential Revision: D23885107

fbshipit-source-id: d71565c442181f739de4d797076ed5d00fb67f8e
1 parent b1eee57 commit 956d3a0
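
As a quick usage sketch (mirroring the new test_to test added below; the construction arguments here are illustrative, not prescriptive, and moving to "cuda:0" assumes a GPU is available):

import torch
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
)

# Build the renderer on CPU first (arguments shortened for illustration).
cpu = torch.device("cpu")
cameras = FoVPerspectiveCameras(device=cpu)
rasterizer = MeshRasterizer(
    cameras=cameras,
    raster_settings=RasterizationSettings(image_size=256),
)
shader = SoftPhongShader(cameras=cameras, lights=PointLights(device=cpu))
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

# With this change, a single call also moves the cameras, lights and materials.
renderer.to(torch.device("cuda:0"))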

File tree

5 files changed: +107 -1 lines changed


pytorch3d/renderer/blending.py

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,7 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
     is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)

     if torch.is_tensor(blend_params.background_color):
-        background_color = blend_params.background_color
+        background_color = blend_params.background_color.to(device)
     else:
         background_color = colors.new_tensor(blend_params.background_color)  # (3)

@@ -163,6 +163,8 @@ def softmax_rgb_blend(
     background = blend_params.background_color
     if not torch.is_tensor(background):
         background = torch.tensor(background, dtype=torch.float32, device=device)
+    else:
+        background = background.to(device)

     # Weight for background color
     eps = 1e-10
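
Both blend functions now follow the same pattern: a background_color that is already a tensor may live on a different device than the fragments being blended, so it is moved explicitly, while a plain tuple of floats is materialized directly on the target device. A standalone sketch of that pattern, with illustrative names (not the library's exact code):

import torch

def _resolve_background_color(background_color, device):
    # Tensors may have been created on another device (e.g. cpu), so move them.
    if torch.is_tensor(background_color):
        return background_color.to(device)
    # Plain tuples/lists of floats are created directly on the target device.
    return torch.tensor(background_color, dtype=torch.float32, device=device)

# e.g. _resolve_background_color((0.0, 0.0, 0.0), torch.device("cuda:0"))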

pytorch3d/renderer/mesh/rasterizer.py

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@ def __init__(self, cameras=None, raster_settings=None):
         self.cameras = cameras
         self.raster_settings = raster_settings

+    def to(self, device):
+        # Manually move to device cameras as it is not a subclass of nn.Module
+        self.cameras = self.cameras.to(device)
+
     def transform(self, meshes_world, **kwargs) -> torch.Tensor:
         """
         Args:
pytorch3d/renderer/mesh/renderer.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,11 @@ def __init__(self, rasterizer, shader):
         self.rasterizer = rasterizer
         self.shader = shader

+    def to(self, device):
+        # Rasterizer and shader have submodules which are not of type nn.Module
+        self.rasterizer.to(device)
+        self.shader.to(device)
+
     def forward(self, meshes_world, **kwargs) -> torch.Tensor:
         """
         Render a batch of images from a batch of meshes by rasterizing and then
@@ -44,6 +49,7 @@ def forward(self, meshes_world, **kwargs) -> torch.Tensor:
         face f, clipping is required before interpolating the texture uv
         coordinates and z buffer so that the colors and depths are limited to
         the range for the corresponding face.
+        For this set rasterizer.raster_settings.clip_barycentric_coords=True
         """
         fragments = self.rasterizer(meshes_world, **kwargs)
         images = self.shader(fragments, meshes_world, **kwargs)

pytorch3d/renderer/mesh/shader.py

Lines changed: 30 additions & 0 deletions
@@ -50,6 +50,12 @@ def __init__(
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

+    def to(self, device):
+        # Manually move to device modules which are not subclasses of nn.Module
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -98,6 +104,12 @@ def __init__(
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

+    def to(self, device):
+        # Manually move to device modules which are not subclasses of nn.Module
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -151,6 +163,12 @@ def __init__(
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

+    def to(self, device):
+        # Manually move to device modules which are not subclasses of nn.Module
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -203,6 +221,12 @@ def __init__(
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

+    def to(self, device):
+        # Manually move to device modules which are not subclasses of nn.Module
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -272,6 +296,12 @@ def __init__(
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

+    def to(self, device):
+        # Manually move to device modules which are not subclasses of nn.Module
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
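
The same `to` method is added to each shader because, as the summary notes, `nn.Module.to` only moves registered parameters, buffers and submodules; attributes of type `TensorProperties` (cameras, lights, materials) are plain Python attributes and are left untouched. A toy sketch of that behaviour (not PyTorch3D code; assumes a CUDA device is available):

import torch
import torch.nn as nn

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))  # moved by nn.Module.to
        self.plain = torch.zeros(3)                 # plain attribute: not moved

m = Holder()
m.to(torch.device("cuda:0"))
print(m.weight.device)  # cuda:0
print(m.plain.device)   # cpu -- hence the explicit to() methods above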

tests/test_render_meshes.py

Lines changed: 64 additions & 0 deletions
@@ -1042,3 +1042,67 @@ def test_simple_sphere_outside_zfar(self):
                 )

             self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_to(self):
+        # Test moving all the tensors in the renderer to a new device
+        # to support multigpu rendering.
+        device1 = torch.device("cpu")
+
+        R, T = look_at_view_transform(1500, 0.0, 0.0)
+
+        # Init shader settings
+        materials = Materials(device=device1)
+        lights = PointLights(device=device1)
+        lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=256, blur_radius=0.0, faces_per_pixel=1
+        )
+        cameras = FoVPerspectiveCameras(
+            device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=torch.zeros(3, dtype=torch.float32, device=device1),
+        )
+
+        shader = SoftPhongShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+
+        def _check_props_on_device(renderer, device):
+            self.assertEqual(renderer.rasterizer.cameras.device, device)
+            self.assertEqual(renderer.shader.cameras.device, device)
+            self.assertEqual(renderer.shader.lights.device, device)
+            self.assertEqual(renderer.shader.lights.ambient_color.device, device)
+            self.assertEqual(renderer.shader.materials.device, device)
+            self.assertEqual(renderer.shader.materials.ambient_color.device, device)
+
+        mesh = ico_sphere(2, device1)
+        verts_padded = mesh.verts_padded()
+        textures = TexturesVertex(
+            verts_features=torch.ones_like(verts_padded, device=device1)
+        )
+        mesh.textures = textures
+        _check_props_on_device(renderer, device1)
+
+        # Test rendering on cpu
+        output_images = renderer(mesh)
+        self.assertEqual(output_images.device, device1)
+
+        # Move renderer and mesh to another device and re render
+        # This also tests that background_color is correctly moved to
+        # the new device
+        device2 = torch.device("cuda:0")
+        renderer.to(device2)
+        mesh = mesh.to(device2)
+        _check_props_on_device(renderer, device2)
+        output_images = renderer(mesh)
+        self.assertEqual(output_images.device, device2)
