@@ -271,6 +271,44 @@ __global__ void elementwise_kernel(int N, func_t f) {
  }
}

+ #ifdef USE_ROCM
+ template <int nt, int vt, typename func_t>
+ C10_LAUNCH_BOUNDS_2(nt, 4)
+ __global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
+   int tid = threadIdx.x;
+   int nv = nt * vt;
+   int idx = nv * blockIdx.x + tid;
+   if ((idx + nt * (vt - 1)) < N) {
+     f(idx, true);
+   } else {
+     #pragma unroll
+     for (int i = 0; i < vt; i++) {
+       if (idx < N) {
+         f(idx, false);
+         idx += nt;
+       }
+     }
+   }
+ }
+
+ template <int nt, int vt, typename func_t>
+ C10_LAUNCH_BOUNDS_2(nt, 4)
+ __global__ void elementwise_kernel_strided(int N, func_t f) {
+   int tid = threadIdx.x;
+   int idx = nt * vt * blockIdx.x + tid;
+   int step = nt * vt * gridDim.x;
+   while (idx < N) {
+     #pragma unroll
+     for (int i = 0; i < vt; i++) {
+       if ((idx + nt * i) < N) {
+         f(idx + nt * i);
+       }
+     }
+     idx += step;
+   }
+ }
+ #endif
+
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
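The two kernels added above split the legacy elementwise work differently. elementwise_kernel_manual_unroll gives each block one contiguous tile of nt * vt elements and passes the functor a second bool argument saying whether the whole tile is in bounds, so the common case can skip per-element bounds checks and only the last, partial tile falls back to the checked loop. elementwise_kernel_strided keeps the one-argument (int idx) functor and instead walks tiles with a grid-wide stride of nt * vt * gridDim.x, which lets the launcher use a fixed grid size. As a sanity check of the index math, here is a small host-only sketch (illustrative, not part of the PR) that replays the strided kernel's loop on the CPU and asserts every element in [0, N) is visited exactly once:

// Host-only sketch: replays elementwise_kernel_strided's indexing for one
// sample configuration and checks full, non-overlapping coverage of [0, N).
#include <cassert>
#include <vector>

int main() {
  constexpr int nt = 128, vt = 4;   // threads per block, elements per thread
  const int N = 100000;             // sample element count
  const int grid = 8192;            // fixed grid used by launch_legacy_kernel_strided
  std::vector<int> hits(N, 0);

  for (int block = 0; block < grid; ++block) {
    for (int tid = 0; tid < nt; ++tid) {
      int idx = nt * vt * block + tid;   // same starting index as the kernel
      const int step = nt * vt * grid;   // grid-wide stride in elements
      while (idx < N) {
        for (int i = 0; i < vt; ++i) {
          if (idx + nt * i < N) {
            ++hits[idx + nt * i];        // stands in for f(idx + nt * i)
          }
        }
        idx += step;
      }
    }
  }
  for (int v : hits) assert(v == 1);     // every element touched exactly once
  return 0;
}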
@@ -284,6 +322,37 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

+ #ifdef USE_ROCM
+ template <int nt, int vt, typename func_t>
+ static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
+   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+   if (N == 0) {
+     return;
+   }
+   dim3 block(nt);
+   dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   elementwise_kernel_manual_unroll<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+   C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+
+ template <int nt, int vt, typename func_t>
+ static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
+   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+   if (N == 0) {
+     return;
+   }
+   dim3 block(nt);
+   dim3 grid(8192);
+   auto stream = at::cuda::getCurrentCUDAStream();
+   int ub_idx = nt * vt;
+   ub_idx = ub_idx * (grid.x - 1) + (block.x - 1);
+   ub_idx = ub_idx + nt * vt;
+   elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+   C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ #endif
+
template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
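The launchers mirror the kernels: launch_legacy_kernel_manual_unroll sizes the grid so that every element belongs to exactly one tile of nt * vt elements, while launch_legacy_kernel_strided caps the grid at 8192 blocks and relies on the grid-stride loop inside the kernel to cover whatever remains (note that ub_idx is computed here but never passed to the kernel in this hunk). A host-side sketch of the grid arithmetic, with illustrative values that are not from the PR:

// Illustrative only: contrasts how the two launchers above size their grids
// for nt = 128, vt = 4 and a sample element count N.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t nt = 128, vt = 4;
  const int64_t N = 1 << 24;  // 16,777,216 elements (example value)

  // launch_legacy_kernel_manual_unroll: one block per tile of nt * vt elements.
  const int64_t grid_unroll = (N + nt * vt - 1) / (nt * vt);      // 32768 blocks

  // launch_legacy_kernel_strided: fixed grid; each block advances by
  // nt * vt * gridDim.x elements per iteration of its while loop.
  const int64_t grid_strided = 8192;
  const int64_t stride = nt * vt * grid_strided;                  // 4,194,304 elements
  const int64_t passes = (N + stride - 1) / stride;               // 4 passes over the grid

  std::printf("manual_unroll: %lld blocks x %lld elements\n",
              (long long)grid_unroll, (long long)(nt * vt));
  std::printf("strided: %lld blocks, stride %lld, %lld pass(es)\n",
              (long long)grid_strided, (long long)stride, (long long)passes);
  return 0;
}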
@@ -362,12 +431,84 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
    return launch_vectorized_kernel(numel, f, data);
  }
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+ #ifndef USE_ROCM
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    *out = invoke(f, &data[1], &offsets[1], 1);
  });
+ #else
+ constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
+ constexpr int grp_sz = 128;
+ launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(numel, [=] GPU_LAMBDA(int idx, bool unrl4x) {
+   if constexpr (unroll_factor == 4) {
+     if (unrl4x) {
+       auto offsets0 = offset_calc.get(idx);
+       auto offsets1 = offset_calc.get(idx + grp_sz);
+       auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+       auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+       arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+       arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+       arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+       arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+       auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+       auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+       auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+       auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+       *out0 = tmp0;
+       *out1 = tmp1;
+       *out2 = tmp2;
+       *out3 = tmp3;
+     }
+     else {
+       auto offsets = offset_calc.get(idx);
+       arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+       *out = invoke(f, &data[1], &offsets[1], 1);
+     }
+   } else {
+     if (unrl4x) {
+       auto offsets0 = offset_calc.get(idx);
+       auto offsets1 = offset_calc.get(idx + grp_sz);
+       auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+       auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+       auto offsets4 = offset_calc.get(idx + grp_sz * 4);
+       auto offsets5 = offset_calc.get(idx + grp_sz * 5);
+       auto offsets6 = offset_calc.get(idx + grp_sz * 6);
+       auto offsets7 = offset_calc.get(idx + grp_sz * 7);
+       arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+       arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+       arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+       arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+       arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
+       arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
+       arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
+       arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
+       auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+       auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+       auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+       auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+       auto tmp4 = invoke(f, &data[1], &offsets4[1], 1);
+       auto tmp5 = invoke(f, &data[1], &offsets5[1], 1);
+       auto tmp6 = invoke(f, &data[1], &offsets6[1], 1);
+       auto tmp7 = invoke(f, &data[1], &offsets7[1], 1);
+       *out0 = tmp0;
+       *out1 = tmp1;
+       *out2 = tmp2;
+       *out3 = tmp3;
+       *out4 = tmp4;
+       *out5 = tmp5;
+       *out6 = tmp6;
+       *out7 = tmp7;
+     }
+     else {
+       auto offsets = offset_calc.get(idx);
+       arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+       *out = invoke(f, &data[1], &offsets[1], 1);
+     }
+   }
+ });
+ #endif
}

template <typename func_t>
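On ROCm, the nocast path above widens the legacy kernel's unroll to 4 elements per thread for outputs of at least 4 bytes (8 otherwise). The unrl4x flag the kernel passes in selects between the fully in-bounds fast path, which computes all offsets and results before doing the stores, and the single-element tail. The same pattern is easier to see outside of TensorIterator; below is a standalone, stripped-down sketch (hypothetical names, not PyTorch code) of a functor with the (int idx, bool full_tile) contract paired with a manual-unroll kernel like the one added earlier in this diff:

// Standalone sketch of the manual-unroll contract: the bool tells the functor
// whether all vt strided elements starting at idx are in bounds.
// Build with hipcc, or nvcc --extended-lambda (device lambdas are required).
#include <cstdio>

template <int nt, int vt, typename func_t>
__global__ void manual_unroll_kernel(int N, func_t f) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  if (idx + nt * (vt - 1) < N) {
    f(idx, true);                                  // whole tile in bounds: no per-element checks
  } else {
    for (int i = 0; i < vt; i++) {
      if (idx < N) { f(idx, false); idx += nt; }   // checked tail, one element at a time
    }
  }
}

int main() {
  constexpr int nt = 128, vt = 4;
  const int N = 1000;
  float* out = nullptr;
  cudaMallocManaged(&out, N * sizeof(float));

  // Functor with the (int idx, bool full_tile) signature: writes 2*idx at each index.
  auto f = [=] __device__ (int idx, bool full_tile) {
    if (full_tile) {
      for (int i = 0; i < vt; i++) out[idx + nt * i] = 2.0f * (idx + nt * i);
    } else {
      out[idx] = 2.0f * idx;                       // caller already checked idx < N
    }
  };

  const int grid = (N + nt * vt - 1) / (nt * vt);
  manual_unroll_kernel<nt, vt><<<grid, nt>>>(N, f);
  cudaDeviceSynchronize();
  std::printf("out[999] = %.1f\n", out[999]);      // expected: 1998.0
  cudaFree(out);
  return 0;
}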
@@ -401,7 +542,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
    dtypes[i] = iter.dtype(i);
    strides[i] = inner_strides[i];
  }
- launch_legacy_kernel<512, 1>(numel, [=] GPU_LAMBDA(int idx) {
+ launch_legacy_kernel_strided<512, 4>(numel, [=] GPU_LAMBDA(int idx) {
    void* out = data[0] + strides[0] * idx;
    arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
    c10::cast_and_store<arg0_t>(dtypes[0], out, result);