Skip to content

Commit ebe2693

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Support variable size radius for points in rasterizer
Summary: Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer. If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50) Reviewed By: jcjohnson, gkioxari Differential Revision: D21429400 fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
1 parent e40c216 commit ebe2693

File tree

8 files changed

+290
-72
lines changed

8 files changed

+290
-72
lines changed

pytorch3d/csrc/rasterize_points/rasterize_points.cu

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint(
3838
float& q_max_z,
3939
int& q_max_idx,
4040
PointQ& q,
41-
const float radius2,
41+
const float* radius,
4242
const float xf,
4343
const float yf,
4444
const int K) {
4545
const float px = points[p_idx * 3 + 0];
4646
const float py = points[p_idx * 3 + 1];
4747
const float pz = points[p_idx * 3 + 2];
48+
const float p_radius = radius[p_idx];
49+
const float radius2 = p_radius * p_radius;
4850
if (pz < 0)
4951
return; // Don't render points behind the camera
5052
const float dx = xf - px;
@@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
8183
const float* points, // (P, 3)
8284
const int64_t* cloud_to_packed_first_idx, // (N)
8385
const int64_t* num_points_per_cloud, // (N)
84-
const float radius,
86+
const float* radius,
8587
const int N,
8688
const int S,
8789
const int K,
@@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel(
9193
// Simple version: One thread per output pixel
9294
const int num_threads = gridDim.x * blockDim.x;
9395
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
94-
const float radius2 = radius * radius;
9596
for (int i = tid; i < N * S * S; i += num_threads) {
9697
// Convert linear index to 3D index
9798
const int n = i / (S * S); // Batch index
@@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
128129

129130
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
130131
CheckPixelInsidePoint(
131-
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
132+
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
132133
}
133134
BubbleSort(q, q_size);
134135
int idx = n * S * S * K + pix_idx * K;
@@ -145,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
145146
const at::Tensor& cloud_to_packed_first_idx, // (N)
146147
const at::Tensor& num_points_per_cloud, // (N)
147148
const int image_size,
148-
const float radius,
149+
const at::Tensor& radius,
149150
const int points_per_pixel) {
150151
// Check inputs are on the same device
151152
at::TensorArg points_t{points, "points", 1},
@@ -194,7 +195,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
194195
points.contiguous().data_ptr<float>(),
195196
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
196197
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
197-
radius,
198+
radius.contiguous().data_ptr<float>(),
198199
N,
199200
S,
200201
K,
@@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel(
214215
const float* points, // (P, 3)
215216
const int64_t* cloud_to_packed_first_idx, // (N)
216217
const int64_t* num_points_per_cloud, // (N)
217-
const float radius,
218+
const float* radius,
218219
const int N,
219220
const int P,
220221
const int S,
@@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
266267
const float px = points[p_idx * 3 + 0];
267268
const float py = points[p_idx * 3 + 1];
268269
const float pz = points[p_idx * 3 + 2];
270+
const float p_radius = radius[p_idx];
269271
if (pz < 0)
270272
continue; // Don't render points behind the camera.
271-
const float px0 = px - radius;
272-
const float px1 = px + radius;
273-
const float py0 = py - radius;
274-
const float py1 = py + radius;
273+
const float px0 = px - p_radius;
274+
const float px1 = px + p_radius;
275+
const float py0 = py - p_radius;
276+
const float py1 = py + p_radius;
275277

276278
// Brute-force search over all bins; TODO something smarter?
277279
// For example we could compute the exact bin where the point falls,
@@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda(
341343
const at::Tensor& cloud_to_packed_first_idx, // (N)
342344
const at::Tensor& num_points_per_cloud, // (N)
343345
const int image_size,
344-
const float radius,
346+
const at::Tensor& radius,
345347
const int bin_size,
346348
const int max_points_per_bin) {
347349
TORCH_CHECK(
@@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda(
390392
points.contiguous().data_ptr<float>(),
391393
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
392394
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
393-
radius,
395+
radius.contiguous().data_ptr<float>(),
394396
N,
395397
P,
396398
image_size,
@@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
411413
__global__ void RasterizePointsFineCudaKernel(
412414
const float* points, // (P, 3)
413415
const int32_t* bin_points, // (N, B, B, T)
414-
const float radius,
416+
const float* radius,
415417
const int bin_size,
416418
const int N,
417419
const int B, // num_bins
@@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel(
425427
const int num_pixels = N * B * B * bin_size * bin_size;
426428
const int num_threads = gridDim.x * blockDim.x;
427429
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
428-
const float radius2 = radius * radius;
429430

430431
for (int pid = tid; pid < num_pixels; pid += num_threads) {
431432
// Convert linear index into bin and pixel indices. We make the within
@@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel(
464465
continue;
465466
}
466467
CheckPixelInsidePoint(
467-
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
468+
points, p, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
468469
}
469470
// Now we've looked at all the points for this bin, so we can write
470471
// output for the current pixel.
@@ -488,7 +489,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
488489
const at::Tensor& points, // (P, 3)
489490
const at::Tensor& bin_points,
490491
const int image_size,
491-
const float radius,
492+
const at::Tensor& radius,
492493
const int bin_size,
493494
const int points_per_pixel) {
494495
// Check inputs are on the same device
@@ -525,7 +526,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
525526
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
526527
points.contiguous().data_ptr<float>(),
527528
bin_points.contiguous().data_ptr<int32_t>(),
528-
radius,
529+
radius.contiguous().data_ptr<float>(),
529530
bin_size,
530531
N,
531532
B,

pytorch3d/csrc/rasterize_points/rasterize_points.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
1515
const torch::Tensor& cloud_to_packed_first_idx,
1616
const torch::Tensor& num_points_per_cloud,
1717
const int image_size,
18-
const float radius,
18+
const torch::Tensor& radius,
1919
const int points_per_pixel);
2020

2121
#ifdef WITH_CUDA
@@ -25,7 +25,7 @@ RasterizePointsNaiveCuda(
2525
const torch::Tensor& cloud_to_packed_first_idx,
2626
const torch::Tensor& num_points_per_cloud,
2727
const int image_size,
28-
const float radius,
28+
const torch::Tensor& radius,
2929
const int points_per_pixel);
3030
#endif
3131
// Naive (forward) pointcloud rasterization: For each pixel, for each point,
@@ -41,7 +41,8 @@ RasterizePointsNaiveCuda(
4141
// in the batch where N is the batch size.
4242
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
4343
// for each pointcloud in the batch.
44-
// radius: Radius of each point (in NDC units)
44+
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
45+
// each point in points.
4546
// image_size: (S) Size of the image to return (in pixels)
4647
// points_per_pixel: (K) The number closest of points to return for each pixel
4748
//
@@ -62,14 +63,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
6263
const torch::Tensor& cloud_to_packed_first_idx,
6364
const torch::Tensor& num_points_per_cloud,
6465
const int image_size,
65-
const float radius,
66+
const torch::Tensor& radius,
6667
const int points_per_pixel) {
6768
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
6869
num_points_per_cloud.is_cuda()) {
6970
#ifdef WITH_CUDA
7071
CHECK_CUDA(points);
7172
CHECK_CUDA(cloud_to_packed_first_idx);
7273
CHECK_CUDA(num_points_per_cloud);
74+
CHECK_CUDA(radius);
7375
return RasterizePointsNaiveCuda(
7476
points,
7577
cloud_to_packed_first_idx,
@@ -100,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
100102
const torch::Tensor& cloud_to_packed_first_idx,
101103
const torch::Tensor& num_points_per_cloud,
102104
const int image_size,
103-
const float radius,
105+
const torch::Tensor& radius,
104106
const int bin_size,
105107
const int max_points_per_bin);
106108

@@ -110,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
110112
const torch::Tensor& cloud_to_packed_first_idx,
111113
const torch::Tensor& num_points_per_cloud,
112114
const int image_size,
113-
const float radius,
115+
const torch::Tensor& radius,
114116
const int bin_size,
115117
const int max_points_per_bin);
116118
#endif
@@ -124,7 +126,8 @@ torch::Tensor RasterizePointsCoarseCuda(
124126
// in the batch where N is the batch size.
125127
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
126128
// for each pointcloud in the batch.
127-
// radius: Radius of points to rasterize (in NDC units)
129+
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
130+
// each point in points.
128131
// image_size: Size of the image to generate (in pixels)
129132
// bin_size: Size of each bin within the image (in pixels)
130133
//
@@ -138,7 +141,7 @@ torch::Tensor RasterizePointsCoarse(
138141
const torch::Tensor& cloud_to_packed_first_idx,
139142
const torch::Tensor& num_points_per_cloud,
140143
const int image_size,
141-
const float radius,
144+
const torch::Tensor& radius,
142145
const int bin_size,
143146
const int max_points_per_bin) {
144147
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
@@ -147,6 +150,7 @@ torch::Tensor RasterizePointsCoarse(
147150
CHECK_CUDA(points);
148151
CHECK_CUDA(cloud_to_packed_first_idx);
149152
CHECK_CUDA(num_points_per_cloud);
153+
CHECK_CUDA(radius);
150154
return RasterizePointsCoarseCuda(
151155
points,
152156
cloud_to_packed_first_idx,
@@ -179,7 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
179183
const torch::Tensor& points,
180184
const torch::Tensor& bin_points,
181185
const int image_size,
182-
const float radius,
186+
const torch::Tensor& radius,
183187
const int bin_size,
184188
const int points_per_pixel);
185189
#endif
@@ -191,7 +195,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
191195
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
192196
// that fall into each bin (output from coarse rasterization)
193197
// image_size: Size of image to generate (in pixels)
194-
// radius: Radius of points to rasterize (NDC units)
198+
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
199+
// each point in points.
195200
// bin_size: Size of each bin (in pixels)
196201
// points_per_pixel: How many points to rasterize for each pixel
197202
//
@@ -210,7 +215,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
210215
const torch::Tensor& points,
211216
const torch::Tensor& bin_points,
212217
const int image_size,
213-
const float radius,
218+
const torch::Tensor& radius,
214219
const int bin_size,
215220
const int points_per_pixel) {
216221
if (points.is_cuda()) {
@@ -296,7 +301,8 @@ torch::Tensor RasterizePointsBackward(
296301
// in the batch where N is the batch size.
297302
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
298303
// for each pointcloud in the batch.
299-
// radius: Radius of each point (in NDC units)
304+
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
305+
// each point in points.
300306
// image_size: (S) Size of the image to return (in pixels)
301307
// points_per_pixel: (K) The number of points to return for each pixel
302308
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
@@ -320,7 +326,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
320326
const torch::Tensor& cloud_to_packed_first_idx,
321327
const torch::Tensor& num_points_per_cloud,
322328
const int image_size,
323-
const float radius,
329+
const torch::Tensor& radius,
324330
const int points_per_pixel,
325331
const int bin_size,
326332
const int max_points_per_bin) {

pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
1717
const torch::Tensor& cloud_to_packed_first_idx, // (N)
1818
const torch::Tensor& num_points_per_cloud, // (N)
1919
const int image_size,
20-
const float radius,
20+
const torch::Tensor& radius,
2121
const int points_per_pixel) {
2222
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
2323

@@ -35,8 +35,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
3535
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
3636
auto zbuf_a = zbuf.accessor<float, 4>();
3737
auto pix_dists_a = pix_dists.accessor<float, 4>();
38+
auto radius_a = radius.accessor<float, 1>();
3839

39-
const float radius2 = radius * radius;
4040
for (int n = 0; n < N; ++n) {
4141
// Loop through each pointcloud in the batch.
4242
// Get the start index of the points in points_packed and the num points
@@ -63,6 +63,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
6363
const float px = points_a[p][0];
6464
const float py = points_a[p][1];
6565
const float pz = points_a[p][2];
66+
const float p_radius = radius_a[p];
67+
const float radius2 = p_radius * p_radius;
6668
if (pz < 0) {
6769
continue;
6870
}
@@ -98,7 +100,7 @@ torch::Tensor RasterizePointsCoarseCpu(
98100
const torch::Tensor& cloud_to_packed_first_idx, // (N)
99101
const torch::Tensor& num_points_per_cloud, // (N)
100102
const int image_size,
101-
const float radius,
103+
const torch::Tensor& radius,
102104
const int bin_size,
103105
const int max_points_per_bin) {
104106
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
@@ -112,6 +114,7 @@ torch::Tensor RasterizePointsCoarseCpu(
112114
auto points_a = points.accessor<float, 2>();
113115
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
114116
auto bin_points_a = bin_points.accessor<int32_t, 4>();
117+
auto radius_a = radius.accessor<float, 1>();
115118

116119
const float pixel_width = 2.0f / image_size;
117120
const float bin_width = pixel_width * bin_size;
@@ -140,13 +143,14 @@ torch::Tensor RasterizePointsCoarseCpu(
140143
float px = points_a[p][0];
141144
float py = points_a[p][1];
142145
float pz = points_a[p][2];
146+
const float p_radius = radius_a[p];
143147
if (pz < 0) {
144148
continue;
145149
}
146-
float point_x_min = px - radius;
147-
float point_x_max = px + radius;
148-
float point_y_min = py - radius;
149-
float point_y_max = py + radius;
150+
float point_x_min = px - p_radius;
151+
float point_x_max = px + p_radius;
152+
float point_y_min = py - p_radius;
153+
float point_y_max = py + p_radius;
150154

151155
// Use a half-open interval so that points exactly on the
152156
// boundary between bins will fall into exactly one bin.

0 commit comments

Comments
 (0)