Skip to content

Commit 68272ab

Browse files
malfet authored and pytorchmergebot committed
Extend cuda_flip to unsigned types (pytorch#137781)
Using AT_DISPATCH_V2. Test plan: `python3 -c "import torch;print(torch.randint(0, 100, (4, 4), dtype=torch.uint16, device='cuda').flip(0))"`. Fixes pytorch#137770. Pull Request resolved: pytorch#137781. Approved by: https://github.com/Skylion007
1 parent 4fa46d3 commit 68272ab

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

aten/src/ATen/native/cuda/IndexKernel.cu

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -457,12 +457,20 @@ void flip_kernel(TensorIterator& iter, const bool quantized) {
457457
flip_kernel_impl<dtype>(iter);
458458
});
459459
} else {
460-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
461-
iter.dtype(), "flip_cuda",
462-
[&] {
463-
using dtype = OpaqueType<sizeof(scalar_t)>;
464-
flip_kernel_impl<dtype>(iter);
465-
});
460+
AT_DISPATCH_V2(
461+
iter.dtype(),
462+
"flip_cuda",
463+
AT_WRAP([&] {
464+
using dtype = OpaqueType<sizeof(scalar_t)>;
465+
flip_kernel_impl<dtype>(iter);
466+
}),
467+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
468+
AT_EXPAND(AT_FLOAT8_TYPES),
469+
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
470+
kComplexHalf,
471+
kHalf,
472+
kBool,
473+
kBFloat16);
466474
}
467475
}
468476

0 commit comments

Comments
 (0)