Skip to content

Commit 3d769a6

Browse files
nikhilaravi authored and facebook-github-bot committed
Non Square image rasterization for pointclouds
Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
1 parent 569e522 commit 3d769a6

22 files changed

+713
-264
lines changed

pytorch3d/csrc/compositing/alpha_composite.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel(
3030
// Get the batch and index
3131
const int batch = blockIdx.x;
3232

33-
const int num_pixels = C * W * H;
33+
const int num_pixels = C * H * W;
3434
const int num_threads = gridDim.y * blockDim.x;
3535
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
3636

3737
// Iterate over each feature in each pixel
3838
for (int pid = tid; pid < num_pixels; pid += num_threads) {
39-
int ch = pid / (W * H);
40-
int j = (pid % (W * H)) / H;
41-
int i = (pid % (W * H)) % H;
39+
int ch = pid / (H * W);
40+
int j = (pid % (H * W)) / W;
41+
int i = (pid % (H * W)) % W;
4242

4343
// alphacomposite the different values
4444
float cum_alpha = 1.;
@@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
8181
// Get the batch and index
8282
const int batch = blockIdx.x;
8383

84-
const int num_pixels = C * W * H;
84+
const int num_pixels = C * H * W;
8585
const int num_threads = gridDim.y * blockDim.x;
8686
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
8787

8888
// Parallelize over each feature in each pixel in images of size H * W,
8989
// for each image in the batch of size batch_size
9090
for (int pid = tid; pid < num_pixels; pid += num_threads) {
91-
int ch = pid / (W * H);
92-
int j = (pid % (W * H)) / H;
93-
int i = (pid % (W * H)) % H;
91+
int ch = pid / (H * W);
92+
int j = (pid % (H * W)) / W;
93+
int i = (pid % (H * W)) % W;
9494

9595
// alphacomposite the different values
9696
float cum_alpha = 1.;

pytorch3d/csrc/compositing/alpha_composite.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
// features: FloatTensor of shape (C, P) which gives the features
1212
// of each point where C is the size of the feature and
1313
// P the number of points.
14-
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
14+
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
1515
// points_per_pixel is the number of points in the z-buffer
16-
// sorted in z-order, and W is the image size.
17-
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
16+
// sorted in z-order, and (H, W) is the image size.
17+
// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
1818
// indices of the nearest points at each pixel, sorted in z-order.
1919
// Returns:
20-
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
20+
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
2121
// feature for each point. Concretely, it gives:
2222
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
2323
// features[c,points_idx[b,k,i,j]]

pytorch3d/csrc/compositing/norm_weighted_sum.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel(
3030
// Get the batch and index
3131
const int batch = blockIdx.x;
3232

33-
const int num_pixels = C * W * H;
33+
const int num_pixels = C * H * W;
3434
const int num_threads = gridDim.y * blockDim.x;
3535
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
3636

3737
// Parallelize over each feature in each pixel in images of size H * W,
3838
// for each image in the batch of size batch_size
3939
for (int pid = tid; pid < num_pixels; pid += num_threads) {
40-
int ch = pid / (W * H);
41-
int j = (pid % (W * H)) / H;
42-
int i = (pid % (W * H)) % H;
40+
int ch = pid / (H * W);
41+
int j = (pid % (H * W)) / W;
42+
int i = (pid % (H * W)) % W;
4343

4444
// Store the accumulated alpha value
4545
float cum_alpha = 0.;
@@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel(
101101
// Parallelize over each feature in each pixel in images of size H * W,
102102
// for each image in the batch of size batch_size
103103
for (int pid = tid; pid < num_pixels; pid += num_threads) {
104-
int ch = pid / (W * H);
105-
int j = (pid % (W * H)) / H;
106-
int i = (pid % (W * H)) % H;
104+
int ch = pid / (H * W);
105+
int j = (pid % (H * W)) / W;
106+
int i = (pid % (H * W)) % W;
107107

108108
float sum_alpha = 0.;
109109
float sum_alphafs = 0.;

pytorch3d/csrc/compositing/norm_weighted_sum.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
// features: FloatTensor of shape (C, P) which gives the features
1212
// of each point where C is the size of the feature and
1313
// P the number of points.
14-
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
14+
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
1515
// points_per_pixel is the number of points in the z-buffer
16-
// sorted in z-order, and W is the image size.
17-
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
16+
// sorted in z-order, and (H, W) is the image size.
17+
// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
1818
// indices of the nearest points at each pixel, sorted in z-order.
1919
// Returns:
20-
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
20+
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
2121
// feature in each point. Concretely, it gives:
2222
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
2323
// features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]

pytorch3d/csrc/compositing/weighted_sum.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel(
2828
// Get the batch and index
2929
const int batch = blockIdx.x;
3030

31-
const int num_pixels = C * W * H;
31+
const int num_pixels = C * H * W;
3232
const int num_threads = gridDim.y * blockDim.x;
3333
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
3434

3535
// Parallelize over each feature in each pixel in images of size H * W,
3636
// for each image in the batch of size batch_size
3737
for (int pid = tid; pid < num_pixels; pid += num_threads) {
38-
int ch = pid / (W * H);
39-
int j = (pid % (W * H)) / H;
40-
int i = (pid % (W * H)) % H;
38+
int ch = pid / (H * W);
39+
int j = (pid % (H * W)) / W;
40+
int i = (pid % (H * W)) % W;
4141

4242
// Iterate through the closest K points for this pixel
4343
for (int k = 0; k < points_idx.size(1); ++k) {
@@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel(
7676
// Get the batch and index
7777
const int batch = blockIdx.x;
7878

79-
const int num_pixels = C * W * H;
79+
const int num_pixels = C * H * W;
8080
const int num_threads = gridDim.y * blockDim.x;
8181
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
8282

8383
// Iterate over each pixel to compute the contribution to the
8484
// gradient for the features and weights
8585
for (int pid = tid; pid < num_pixels; pid += num_threads) {
86-
int ch = pid / (W * H);
87-
int j = (pid % (W * H)) / H;
88-
int i = (pid % (W * H)) % H;
86+
int ch = pid / (H * W);
87+
int j = (pid % (H * W)) / W;
88+
int i = (pid % (H * W)) % W;
8989

9090
// Iterate through the closest K points for this pixel
9191
for (int k = 0; k < points_idx.size(1); ++k) {

pytorch3d/csrc/compositing/weighted_sum.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
// features: FloatTensor of shape (C, P) which gives the features
1212
// of each point where C is the size of the feature and
1313
// P the number of points.
14-
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
14+
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
1515
// points_per_pixel is the number of points in the z-buffer
16-
// sorted in z-order, and W is the image size.
16+
// sorted in z-order, and (H, W) is the image size.
1717
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
1818
// indices of the nearest points at each pixel, sorted in z-order.
1919
// Returns:
20-
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
20+
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
2121
// feature in each point. Concretely, it gives:
2222
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
2323
// features[c,points_idx[b,k,i,j]]

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
452452
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
453453
const float sign = inside ? -1.0f : 1.0f;
454454

455-
// TODO(T52813608) Add support for non-square images.
456455
auto grad_dist_f = PointTriangleDistanceBackward(
457456
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
458457
const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
@@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
606605
const float half_pix_x = NDC_x_half_range / W;
607606
const float half_pix_y = NDC_y_half_range / H;
608607

609-
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
608+
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
610609
// stored in shared memory that will track whether each point in the chunk
611610
// falls into each bin of the image.
612611
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
@@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
755754
const int num_bins_y = 1 + (H - 1) / bin_size;
756755
const int num_bins_x = 1 + (W - 1) / bin_size;
757756

758-
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
757+
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
759758
std::stringstream ss;
760759
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
761760
<< ", num_bins_x: " << num_bins_x << ", "
@@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
800799
// ****************************************************************************
801800
__global__ void RasterizeMeshesFineCudaKernel(
802801
const float* face_verts, // (F, 3, 3)
803-
const int32_t* bin_faces, // (N, B, B, T)
802+
const int32_t* bin_faces, // (N, BH, BW, T)
804803
const float blur_radius,
805804
const int bin_size,
806805
const bool perspective_correct,
@@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
813812
const int H,
814813
const int W,
815814
const int K,
816-
int64_t* face_idxs, // (N, S, S, K)
817-
float* zbuf, // (N, S, S, K)
818-
float* pix_dists, // (N, S, S, K)
819-
float* bary // (N, S, S, K, 3)
815+
int64_t* face_idxs, // (N, H, W, K)
816+
float* zbuf, // (N, H, W, K)
817+
float* pix_dists, // (N, H, W, K)
818+
float* bary // (N, H, W, K, 3)
820819
) {
821-
// This can be more than S^2 if S % bin_size != 0
820+
// This can be more than H * W if H or W are not divisible by bin_size.
822821
int num_pixels = N * BH * BW * bin_size * bin_size;
823822
int num_threads = gridDim.x * blockDim.x;
824823
int tid = blockIdx.x * blockDim.x + threadIdx.x;

pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,11 @@
55
#include <list>
66
#include <queue>
77
#include <tuple>
8+
#include "rasterize_points/rasterization_utils.h"
89
#include "utils/geometry_utils.h"
910
#include "utils/vec2.h"
1011
#include "utils/vec3.h"
1112

12-
// The default value of the NDC range is [-1, 1], however in the case that
13-
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
14-
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
15-
// the NDC range is calculated and S2 is the other image dimension.
16-
// e.g. to get the NDC x range S1 = W and S2 = H
17-
float NonSquareNdcRange(int S1, int S2) {
18-
float range = 2.0f;
19-
if (S1 > S2) {
20-
range = ((S1 / S2) * range);
21-
}
22-
return range;
23-
}
24-
25-
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
26-
// coordinates. We divide the NDC range into S1 evenly-sized
27-
// pixels, and assume that each pixel falls in the *center* of its range.
28-
// The default value of the NDC range is [-1, 1], however in the case that
29-
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
30-
// the longer side is scaled by the ratio of H:W. The dimension of i should be
31-
// S1 and the other image dimension is S2 For example, to get the x and y NDC
32-
// coordinates or a given pixel i:
33-
// x = PixToNonSquareNdc(i, W, H)
34-
// y = PixToNonSquareNdc(i, H, W)
35-
float PixToNonSquareNdc(int i, int S1, int S2) {
36-
float range = NonSquareNdcRange(S1, S2);
37-
// NDC: offset + (i * pixel_width + half_pixel_width)
38-
// The NDC range is [-range/2, range/2].
39-
const float offset = (range / 2.0f);
40-
return -offset + (range * i + offset) / S1;
41-
}
42-
4313
// Get (x, y, z) values for vertex from (3, 3) tensor face.
4414
template <typename Face>
4515
auto ExtractVerts(const Face& face, const int vertex_index) {

pytorch3d/csrc/rasterize_points/rasterization_utils.cuh

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,6 @@
22

33
#pragma once
44

5-
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
6-
// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized
7-
// pixels, and assume that each pixel falls in the *center* of its range.
8-
// TODO: delete this function after updating the pointcloud rasterizer to
9-
// support non square images.
10-
__device__ inline float PixToNdc(int i, int S) {
11-
// NDC: x-offset + (i * pixel_width + half_pixel_width)
12-
return -1.0 + (2 * i + 1.0) / S;
13-
}
14-
155
// The default value of the NDC range is [-1, 1], however in the case that
166
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
177
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
@@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
5040
// TODO: is 8 enough? Would increasing have performance considerations?
5141
const int32_t kMaxPointsPerPixel = 150;
5242

53-
const int32_t kMaxFacesPerBin = 22;
43+
const int32_t kMaxItemsPerBin = 22;
5444

5545
template <typename T>
5646
__device__ inline void BubbleSort(T* arr, int n) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#pragma once
4+
5+
// The default value of the NDC range is [-1, 1], however in the case that
6+
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
7+
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
8+
// the NDC range is calculated and S2 is the other image dimension.
9+
// e.g. to get the NDC x range S1 = W and S2 = H
10+
inline float NonSquareNdcRange(int S1, int S2) {
11+
float range = 2.0f;
12+
if (S1 > S2) {
13+
range = ((S1 / S2) * range);
14+
}
15+
return range;
16+
}
17+
18+
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
19+
// coordinates. We divide the NDC range into S1 evenly-sized
20+
// pixels, and assume that each pixel falls in the *center* of its range.
21+
// The default value of the NDC range is [-1, 1], however in the case that
22+
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
23+
// the longer side is scaled by the ratio of H:W. The dimension of i should be
24+
// S1 and the other image dimension is S2 For example, to get the x and y NDC
25+
// coordinates or a given pixel i:
26+
// x = PixToNonSquareNdc(i, W, H)
27+
// y = PixToNonSquareNdc(i, H, W)
28+
inline float PixToNonSquareNdc(int i, int S1, int S2) {
29+
float range = NonSquareNdcRange(S1, S2);
30+
// NDC: offset + (i * pixel_width + half_pixel_width)
31+
// The NDC range is [-range/2, range/2].
32+
const float offset = (range / 2.0f);
33+
return -offset + (range * i + offset) / S1;
34+
}

0 commit comments

Comments
 (0)