@@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint(
38
38
float & q_max_z,
39
39
int & q_max_idx,
40
40
PointQ& q,
41
- const float radius2 ,
41
+ const float * radius ,
42
42
const float xf,
43
43
const float yf,
44
44
const int K) {
45
45
const float px = points[p_idx * 3 + 0 ];
46
46
const float py = points[p_idx * 3 + 1 ];
47
47
const float pz = points[p_idx * 3 + 2 ];
48
+ const float p_radius = radius[p_idx];
49
+ const float radius2 = p_radius * p_radius;
48
50
if (pz < 0 )
49
51
return ; // Don't render points behind the camera
50
52
const float dx = xf - px;
@@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
81
83
const float * points, // (P, 3)
82
84
const int64_t * cloud_to_packed_first_idx, // (N)
83
85
const int64_t * num_points_per_cloud, // (N)
84
- const float radius,
86
+ const float * radius,
85
87
const int N,
86
88
const int S,
87
89
const int K,
@@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel(
91
93
// Simple version: One thread per output pixel
92
94
const int num_threads = gridDim .x * blockDim .x ;
93
95
const int tid = blockDim .x * blockIdx .x + threadIdx .x ;
94
- const float radius2 = radius * radius;
95
96
for (int i = tid; i < N * S * S; i += num_threads) {
96
97
// Convert linear index to 3D index
97
98
const int n = i / (S * S); // Batch index
@@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
128
129
129
130
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
130
131
CheckPixelInsidePoint (
131
- points, p_idx, q_size, q_max_z, q_max_idx, q, radius2 , xf, yf, K);
132
+ points, p_idx, q_size, q_max_z, q_max_idx, q, radius , xf, yf, K);
132
133
}
133
134
BubbleSort (q, q_size);
134
135
int idx = n * S * S * K + pix_idx * K;
@@ -145,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
145
146
const at::Tensor& cloud_to_packed_first_idx, // (N)
146
147
const at::Tensor& num_points_per_cloud, // (N)
147
148
const int image_size,
148
- const float radius,
149
+ const at::Tensor& radius,
149
150
const int points_per_pixel) {
150
151
// Check inputs are on the same device
151
152
at::TensorArg points_t {points, " points" , 1 },
@@ -194,7 +195,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
194
195
points.contiguous ().data_ptr <float >(),
195
196
cloud_to_packed_first_idx.contiguous ().data_ptr <int64_t >(),
196
197
num_points_per_cloud.contiguous ().data_ptr <int64_t >(),
197
- radius,
198
+ radius. contiguous (). data_ptr < float >() ,
198
199
N,
199
200
S,
200
201
K,
@@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel(
214
215
const float * points, // (P, 3)
215
216
const int64_t * cloud_to_packed_first_idx, // (N)
216
217
const int64_t * num_points_per_cloud, // (N)
217
- const float radius,
218
+ const float * radius,
218
219
const int N,
219
220
const int P,
220
221
const int S,
@@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
266
267
const float px = points[p_idx * 3 + 0 ];
267
268
const float py = points[p_idx * 3 + 1 ];
268
269
const float pz = points[p_idx * 3 + 2 ];
270
+ const float p_radius = radius[p_idx];
269
271
if (pz < 0 )
270
272
continue ; // Don't render points behind the camera.
271
- const float px0 = px - radius ;
272
- const float px1 = px + radius ;
273
- const float py0 = py - radius ;
274
- const float py1 = py + radius ;
273
+ const float px0 = px - p_radius ;
274
+ const float px1 = px + p_radius ;
275
+ const float py0 = py - p_radius ;
276
+ const float py1 = py + p_radius ;
275
277
276
278
// Brute-force search over all bins; TODO something smarter?
277
279
// For example we could compute the exact bin where the point falls,
@@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda(
341
343
const at::Tensor& cloud_to_packed_first_idx, // (N)
342
344
const at::Tensor& num_points_per_cloud, // (N)
343
345
const int image_size,
344
- const float radius,
346
+ const at::Tensor& radius,
345
347
const int bin_size,
346
348
const int max_points_per_bin) {
347
349
TORCH_CHECK (
@@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda(
390
392
points.contiguous ().data_ptr <float >(),
391
393
cloud_to_packed_first_idx.contiguous ().data_ptr <int64_t >(),
392
394
num_points_per_cloud.contiguous ().data_ptr <int64_t >(),
393
- radius,
395
+ radius. contiguous (). data_ptr < float >() ,
394
396
N,
395
397
P,
396
398
image_size,
@@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
411
413
__global__ void RasterizePointsFineCudaKernel (
412
414
const float * points, // (P, 3)
413
415
const int32_t * bin_points, // (N, B, B, T)
414
- const float radius,
416
+ const float * radius,
415
417
const int bin_size,
416
418
const int N,
417
419
const int B, // num_bins
@@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel(
425
427
const int num_pixels = N * B * B * bin_size * bin_size;
426
428
const int num_threads = gridDim .x * blockDim .x ;
427
429
const int tid = blockIdx .x * blockDim .x + threadIdx .x ;
428
- const float radius2 = radius * radius;
429
430
430
431
for (int pid = tid; pid < num_pixels; pid += num_threads) {
431
432
// Convert linear index into bin and pixel indices. We make the within
@@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel(
464
465
continue ;
465
466
}
466
467
CheckPixelInsidePoint (
467
- points, p, q_size, q_max_z, q_max_idx, q, radius2 , xf, yf, K);
468
+ points, p, q_size, q_max_z, q_max_idx, q, radius , xf, yf, K);
468
469
}
469
470
// Now we've looked at all the points for this bin, so we can write
470
471
// output for the current pixel.
@@ -488,7 +489,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
488
489
const at::Tensor& points, // (P, 3)
489
490
const at::Tensor& bin_points,
490
491
const int image_size,
491
- const float radius,
492
+ const at::Tensor& radius,
492
493
const int bin_size,
493
494
const int points_per_pixel) {
494
495
// Check inputs are on the same device
@@ -525,7 +526,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
525
526
RasterizePointsFineCudaKernel<<<blocks, threads, 0 , stream>>> (
526
527
points.contiguous ().data_ptr <float >(),
527
528
bin_points.contiguous ().data_ptr <int32_t >(),
528
- radius,
529
+ radius. contiguous (). data_ptr < float >() ,
529
530
bin_size,
530
531
N,
531
532
B,
0 commit comments