Skip to content

Commit a61deee

Browse files
authored
Fix translation of Shuffle ops for sycl::bfloat16 and sycl::half (#3231)
Extend #2339 to support also `OpGroupNonUniformShuffle`, `OpGroupNonUniformShuffleUp`, and `OpGroupNonUniformShuffleXor`.
1 parent 221039c commit a61deee

File tree

3 files changed

+95
-57
lines changed

3 files changed

+95
-57
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6846,7 +6846,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
68466846
return BM->addCooperativeMatrixLengthKHRInst(
68476847
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
68486848
}
6849-
case OpGroupNonUniformShuffleDown: {
6849+
case OpGroupNonUniformShuffle:
6850+
case OpGroupNonUniformShuffleDown:
6851+
case OpGroupNonUniformShuffleUp:
6852+
case OpGroupNonUniformShuffleXor: {
68506853
Function *F = CI->getCalledFunction();
68516854
if (F->arg_size() && F->getArg(0)->hasStructRetAttr()) {
68526855
StructType *St = cast<StructType>(F->getParamStructRetType(0));
@@ -6863,9 +6866,8 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
68636866
SPIRVType *ElementTy = transType(MemberTy);
68646867
SPIRVValue *Element0 =
68656868
BM->addCompositeExtractInst(ElementTy, Composite0, {0}, BB);
6866-
SPIRVValue *Src =
6867-
BM->addGroupInst(OpGroupNonUniformShuffleDown, ElementTy,
6868-
static_cast<Scope>(ScopeId), {Element0, Delta}, BB);
6869+
SPIRVValue *Src = BM->addGroupInst(
6870+
OC, ElementTy, static_cast<Scope>(ScopeId), {Element0, Delta}, BB);
68696871
SPIRVValue *Composite1 =
68706872
BM->addCompositeInsertInst(Src, Composite0, {0}, BB);
68716873
return BM->addStoreInst(InValue, Composite1, {}, BB);

test/group_non_uniform_shuffle.ll

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-TYPED-PTR %s
3+
; RUN: llvm-spirv %t.bc -o %t.spv
4+
; RUN: spirv-val %t.spv
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck --check-prefix CHECK-LLVM %s
7+
8+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_untyped_pointers -spirv-text -o - | FileCheck --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-UNTYPED-PTR %s
9+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_untyped_pointers -o %t.spv
10+
; RUN: spirv-val %t.spv
11+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
12+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck --check-prefix CHECK-LLVM %s
13+
14+
; CHECK-SPIRV-DAG: TypeInt [[#I32:]] 32 0
15+
; CHECK-SPIRV-DAG: Constant [[#I32]] [[#CONST_I32_3:]] 3
16+
; CHECK-SPIRV-DAG: Constant [[#I32]] [[#CONST_I32_8:]] 8
17+
; CHECK-SPIRV-DAG: TypeFloat [[#HALF:]] 16
18+
; CHECK-SPIRV-DAG: TypeStruct [[#S_HALF:]] [[#HALF]]
19+
; CHECK-SPIRV-TYPED-PTR-DAG: TypePointer [[#PTR_S_HALF:]] {{[0-9]+}} [[#S_HALF]]
20+
; CHECK-SPIRV-UNTYPED-PTR-DAG: TypeUntypedPointerKHR [[#PTR:]] [[#]]
21+
22+
target triple = "spir64-unknown-unknown"
23+
24+
%"class.sycl::_V1::detail::half_impl::half" = type { half }
25+
26+
define spir_func void @test_group_non_uniform_shuffle_down() {
27+
entry:
28+
%agg.tmp.i.i = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
29+
%ref.tmp.i = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
30+
%ref.tmp.ascast.i = addrspacecast ptr %ref.tmp.i to ptr addrspace(4)
31+
call spir_func void @_Z30__spirv_GroupNonUniformShuffleIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
32+
call spir_func void @_Z34__spirv_GroupNonUniformShuffleDownIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
33+
call spir_func void @_Z32__spirv_GroupNonUniformShuffleUpIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
34+
call spir_func void @_Z33__spirv_GroupNonUniformShuffleXorIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
35+
ret void
36+
}
37+
38+
; CHECK-SPIRV-TYPED-PTR: Variable {{[0-9]+}} {{[0-9]+}}
39+
; CHECK-SPIRV-TYPED-PTR: Variable [[#PTR_S_HALF]] [[#VAR_0:]]
40+
; CHECK-SPIRV-UNTYPED-PTR: UntypedVariableKHR {{[0-9]+}} {{[0-9]+}}
41+
; CHECK-SPIRV-UNTYPED-PTR: UntypedVariableKHR [[#PTR]] [[#VAR_0:]] [[#HALF]]
42+
; CHECK-SPIRV: Load [[#S_HALF]] [[#COMP_0:]] [[#VAR_0]]
43+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#ELEM_0:]] [[#COMP_0]] 0
44+
; CHECK-SPIRV: GroupNonUniformShuffle [[#HALF]] [[#ELEM_1:]] [[#CONST_I32_3]] [[#ELEM_0]] [[#CONST_I32_8]]
45+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#COMP_1:]] [[#ELEM_1]] [[#COMP_0]] 0
46+
; CHECK-SPIRV: Store [[#VAR_0]] [[#COMP_1]]
47+
; CHECK-SPIRV: Load [[#S_HALF]] [[#DOWN_COMP_0:]] [[#VAR_0]]
48+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#DOWN_ELEM_0:]] [[#DOWN_COMP_0]] 0
49+
; CHECK-SPIRV: GroupNonUniformShuffleDown [[#HALF]] [[#DOWN_ELEM_1:]] [[#CONST_I32_3]] [[#DOWN_ELEM_0]] [[#CONST_I32_8]]
50+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#DOWN_COMP_1:]] [[#DOWN_ELEM_1]] [[#DOWN_COMP_0]] 0
51+
; CHECK-SPIRV: Store [[#VAR_0]] [[#DOWN_COMP_1]]
52+
; CHECK-SPIRV: Load [[#S_HALF]] [[#UP_COMP_0:]] [[#VAR_0]]
53+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#UP_ELEM_0:]] [[#UP_COMP_0]] 0
54+
; CHECK-SPIRV: GroupNonUniformShuffleUp [[#HALF]] [[#UP_ELEM_1:]] [[#CONST_I32_3]] [[#UP_ELEM_0]] [[#CONST_I32_8]]
55+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#UP_COMP_1:]] [[#UP_ELEM_1]] [[#UP_COMP_0]] 0
56+
; CHECK-SPIRV: Store [[#VAR_0]] [[#UP_COMP_1]]
57+
; CHECK-SPIRV: Load [[#S_HALF]] [[#XOR_COMP_0:]] [[#VAR_0]]
58+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#XOR_ELEM_0:]] [[#XOR_COMP_0]] 0
59+
; CHECK-SPIRV: GroupNonUniformShuffleXor [[#HALF]] [[#XOR_ELEM_1:]] [[#CONST_I32_3]] [[#XOR_ELEM_0]] [[#CONST_I32_8]]
60+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#XOR_COMP_1:]] [[#XOR_ELEM_1]] [[#XOR_COMP_0]] 0
61+
; CHECK-SPIRV: Store [[#VAR_0]] [[#XOR_COMP_1]]
62+
63+
; CHECK-LLVM: [[ALLOCA_0:%[a-z0-9.]+]] = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
64+
; CHECK-LLVM: [[ALLOCA_1:%[a-z0-9.]+]] = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
65+
; CHECK-LLVM: [[LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
66+
; CHECK-LLVM: [[EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[LOAD_0]], 0
67+
; CHECK-LLVM: [[CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z17sub_group_shuffleDhj(half [[EXTRACT_0]], i32 8)
68+
; CHECK-LLVM: [[INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[LOAD_0]], half [[CALL_0]], 0
69+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[INSERT_0]], ptr [[ALLOCA_1]], align 2
70+
; CHECK-LLVM: [[DOWN_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
71+
; CHECK-LLVM: [[DOWN_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[DOWN_LOAD_0]], 0
72+
; CHECK-LLVM: [[DOWN_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z22sub_group_shuffle_downDhj(half [[DOWN_EXTRACT_0]], i32 8)
73+
; CHECK-LLVM: [[DOWN_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[DOWN_LOAD_0]], half [[DOWN_CALL_0]], 0
74+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[DOWN_INSERT_0]], ptr [[ALLOCA_1]], align 2
75+
; CHECK-LLVM: [[UP_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
76+
; CHECK-LLVM: [[UP_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[UP_LOAD_0]], 0
77+
; CHECK-LLVM: [[UP_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z20sub_group_shuffle_upDhj(half [[UP_EXTRACT_0]], i32 8)
78+
; CHECK-LLVM: [[UP_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[UP_LOAD_0]], half [[UP_CALL_0]], 0
79+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[UP_INSERT_0]], ptr [[ALLOCA_1]], align 2
80+
; CHECK-LLVM: [[XOR_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
81+
; CHECK-LLVM: [[XOR_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[XOR_LOAD_0]], 0
82+
; CHECK-LLVM: [[XOR_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z21sub_group_shuffle_xorDhj(half [[XOR_EXTRACT_0]], i32 8)
83+
; CHECK-LLVM: [[XOR_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[XOR_LOAD_0]], half [[XOR_CALL_0]], 0
84+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[XOR_INSERT_0]], ptr [[ALLOCA_1]], align 2
85+
86+
declare dso_local spir_func void @_Z30__spirv_GroupNonUniformShuffleIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
87+
declare dso_local spir_func void @_Z34__spirv_GroupNonUniformShuffleDownIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
88+
declare dso_local spir_func void @_Z32__spirv_GroupNonUniformShuffleUpIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
89+
declare dso_local spir_func void @_Z33__spirv_GroupNonUniformShuffleXorIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr

test/group_non_uniform_shuffle_down.ll

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)