Skip to content

Commit f4b3aac

Browse files
maarquitos14 authored and mikolaj-pirog committed
Fix translation of Shuffle ops for sycl::bfloat16 and sycl::half (#3231)
Extend KhronosGroup/SPIRV-LLVM-Translator#2339 to also support `OpGroupNonUniformShuffle`, `OpGroupNonUniformShuffleUp`, and `OpGroupNonUniformShuffleXor`. Original commit: KhronosGroup/SPIRV-LLVM-Translator@a61deeef1dd3e35
1 parent 8120994 commit f4b3aac

File tree

3 files changed

+99
-57
lines changed

3 files changed

+99
-57
lines changed

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6840,7 +6840,14 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
68406840
transValue(CI->getArgOperand(2), BB), BB);
68416841
return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB);
68426842
}
6843-
case OpGroupNonUniformShuffleDown: {
6843+
case OpCooperativeMatrixLengthKHR: {
6844+
return BM->addCooperativeMatrixLengthKHRInst(
6845+
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
6846+
}
6847+
case OpGroupNonUniformShuffle:
6848+
case OpGroupNonUniformShuffleDown:
6849+
case OpGroupNonUniformShuffleUp:
6850+
case OpGroupNonUniformShuffleXor: {
68446851
Function *F = CI->getCalledFunction();
68456852
if (F->arg_size() && F->getArg(0)->hasStructRetAttr()) {
68466853
StructType *St = cast<StructType>(F->getParamStructRetType(0));
@@ -6857,9 +6864,8 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
68576864
SPIRVType *ElementTy = transType(MemberTy);
68586865
SPIRVValue *Element0 =
68596866
BM->addCompositeExtractInst(ElementTy, Composite0, {0}, BB);
6860-
SPIRVValue *Src =
6861-
BM->addGroupInst(OpGroupNonUniformShuffleDown, ElementTy,
6862-
static_cast<Scope>(ScopeId), {Element0, Delta}, BB);
6867+
SPIRVValue *Src = BM->addGroupInst(
6868+
OC, ElementTy, static_cast<Scope>(ScopeId), {Element0, Delta}, BB);
68636869
SPIRVValue *Composite1 =
68646870
BM->addCompositeInsertInst(Src, Composite0, {0}, BB);
68656871
return BM->addStoreInst(InValue, Composite1, {}, BB);
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-TYPED-PTR %s
3+
; RUN: llvm-spirv %t.bc -o %t.spv
4+
; RUN: spirv-val %t.spv
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck --check-prefix CHECK-LLVM %s
7+
8+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_untyped_pointers -spirv-text -o - | FileCheck --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-UNTYPED-PTR %s
9+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_untyped_pointers -o %t.spv
10+
; RUN: spirv-val %t.spv
11+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
12+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck --check-prefix CHECK-LLVM %s
13+
14+
; CHECK-SPIRV-DAG: TypeInt [[#I32:]] 32 0
15+
; CHECK-SPIRV-DAG: Constant [[#I32]] [[#CONST_I32_3:]] 3
16+
; CHECK-SPIRV-DAG: Constant [[#I32]] [[#CONST_I32_8:]] 8
17+
; CHECK-SPIRV-DAG: TypeFloat [[#HALF:]] 16
18+
; CHECK-SPIRV-DAG: TypeStruct [[#S_HALF:]] [[#HALF]]
19+
; CHECK-SPIRV-TYPED-PTR-DAG: TypePointer [[#PTR_S_HALF:]] {{[0-9]+}} [[#S_HALF]]
20+
; CHECK-SPIRV-UNTYPED-PTR-DAG: TypeUntypedPointerKHR [[#PTR:]] [[#]]
21+
22+
target triple = "spir64-unknown-unknown"
23+
24+
%"class.sycl::_V1::detail::half_impl::half" = type { half }
25+
26+
define spir_func void @test_group_non_uniform_shuffle_down() {
27+
entry:
28+
%agg.tmp.i.i = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
29+
%ref.tmp.i = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
30+
%ref.tmp.ascast.i = addrspacecast ptr %ref.tmp.i to ptr addrspace(4)
31+
call spir_func void @_Z30__spirv_GroupNonUniformShuffleIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
32+
call spir_func void @_Z34__spirv_GroupNonUniformShuffleDownIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
33+
call spir_func void @_Z32__spirv_GroupNonUniformShuffleUpIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
34+
call spir_func void @_Z33__spirv_GroupNonUniformShuffleXorIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %ref.tmp.ascast.i, i32 noundef 3, ptr noundef nonnull byval(%"class.sycl::_V1::detail::half_impl::half") align 2 %agg.tmp.i.i, i32 noundef 8)
35+
ret void
36+
}
37+
38+
; CHECK-SPIRV-TYPED-PTR: Variable {{[0-9]+}} {{[0-9]+}}
39+
; CHECK-SPIRV-TYPED-PTR: Variable [[#PTR_S_HALF]] [[#VAR_0:]]
40+
; CHECK-SPIRV-UNTYPED-PTR: UntypedVariableKHR {{[0-9]+}} {{[0-9]+}}
41+
; CHECK-SPIRV-UNTYPED-PTR: UntypedVariableKHR [[#PTR]] [[#VAR_0:]] [[#HALF]]
42+
; CHECK-SPIRV: Load [[#S_HALF]] [[#COMP_0:]] [[#VAR_0]]
43+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#ELEM_0:]] [[#COMP_0]] 0
44+
; CHECK-SPIRV: GroupNonUniformShuffle [[#HALF]] [[#ELEM_1:]] [[#CONST_I32_3]] [[#ELEM_0]] [[#CONST_I32_8]]
45+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#COMP_1:]] [[#ELEM_1]] [[#COMP_0]] 0
46+
; CHECK-SPIRV: Store [[#VAR_0]] [[#COMP_1]]
47+
; CHECK-SPIRV: Load [[#S_HALF]] [[#DOWN_COMP_0:]] [[#VAR_0]]
48+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#DOWN_ELEM_0:]] [[#DOWN_COMP_0]] 0
49+
; CHECK-SPIRV: GroupNonUniformShuffleDown [[#HALF]] [[#DOWN_ELEM_1:]] [[#CONST_I32_3]] [[#DOWN_ELEM_0]] [[#CONST_I32_8]]
50+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#DOWN_COMP_1:]] [[#DOWN_ELEM_1]] [[#DOWN_COMP_0]] 0
51+
; CHECK-SPIRV: Store [[#VAR_0]] [[#DOWN_COMP_1]]
52+
; CHECK-SPIRV: Load [[#S_HALF]] [[#UP_COMP_0:]] [[#VAR_0]]
53+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#UP_ELEM_0:]] [[#UP_COMP_0]] 0
54+
; CHECK-SPIRV: GroupNonUniformShuffleUp [[#HALF]] [[#UP_ELEM_1:]] [[#CONST_I32_3]] [[#UP_ELEM_0]] [[#CONST_I32_8]]
55+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#UP_COMP_1:]] [[#UP_ELEM_1]] [[#UP_COMP_0]] 0
56+
; CHECK-SPIRV: Store [[#VAR_0]] [[#UP_COMP_1]]
57+
; CHECK-SPIRV: Load [[#S_HALF]] [[#XOR_COMP_0:]] [[#VAR_0]]
58+
; CHECK-SPIRV: CompositeExtract [[#HALF]] [[#XOR_ELEM_0:]] [[#XOR_COMP_0]] 0
59+
; CHECK-SPIRV: GroupNonUniformShuffleXor [[#HALF]] [[#XOR_ELEM_1:]] [[#CONST_I32_3]] [[#XOR_ELEM_0]] [[#CONST_I32_8]]
60+
; CHECK-SPIRV: CompositeInsert [[#S_HALF]] [[#XOR_COMP_1:]] [[#XOR_ELEM_1]] [[#XOR_COMP_0]] 0
61+
; CHECK-SPIRV: Store [[#VAR_0]] [[#XOR_COMP_1]]
62+
63+
; CHECK-LLVM: [[ALLOCA_0:%[a-z0-9.]+]] = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
64+
; CHECK-LLVM: [[ALLOCA_1:%[a-z0-9.]+]] = alloca %"class.sycl::_V1::detail::half_impl::half", align 2
65+
; CHECK-LLVM: [[LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
66+
; CHECK-LLVM: [[EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[LOAD_0]], 0
67+
; CHECK-LLVM: [[CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z17sub_group_shuffleDhj(half [[EXTRACT_0]], i32 8)
68+
; CHECK-LLVM: [[INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[LOAD_0]], half [[CALL_0]], 0
69+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[INSERT_0]], ptr [[ALLOCA_1]], align 2
70+
; CHECK-LLVM: [[DOWN_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
71+
; CHECK-LLVM: [[DOWN_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[DOWN_LOAD_0]], 0
72+
; CHECK-LLVM: [[DOWN_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z22sub_group_shuffle_downDhj(half [[DOWN_EXTRACT_0]], i32 8)
73+
; CHECK-LLVM: [[DOWN_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[DOWN_LOAD_0]], half [[DOWN_CALL_0]], 0
74+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[DOWN_INSERT_0]], ptr [[ALLOCA_1]], align 2
75+
; CHECK-LLVM: [[UP_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
76+
; CHECK-LLVM: [[UP_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[UP_LOAD_0]], 0
77+
; CHECK-LLVM: [[UP_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z20sub_group_shuffle_upDhj(half [[UP_EXTRACT_0]], i32 8)
78+
; CHECK-LLVM: [[UP_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[UP_LOAD_0]], half [[UP_CALL_0]], 0
79+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[UP_INSERT_0]], ptr [[ALLOCA_1]], align 2
80+
; CHECK-LLVM: [[XOR_LOAD_0:%[a-z0-9.]+]] = load %"class.sycl::_V1::detail::half_impl::half", ptr [[ALLOCA_1]], align 2
81+
; CHECK-LLVM: [[XOR_EXTRACT_0:%[a-z0-9.]+]] = extractvalue %"class.sycl::_V1::detail::half_impl::half" [[XOR_LOAD_0]], 0
82+
; CHECK-LLVM: [[XOR_CALL_0:%[a-z0-9.]+]] = call spir_func half @_Z21sub_group_shuffle_xorDhj(half [[XOR_EXTRACT_0]], i32 8)
83+
; CHECK-LLVM: [[XOR_INSERT_0:%[a-z0-9.]+]] = insertvalue %"class.sycl::_V1::detail::half_impl::half" [[XOR_LOAD_0]], half [[XOR_CALL_0]], 0
84+
; CHECK-LLVM: store %"class.sycl::_V1::detail::half_impl::half" [[XOR_INSERT_0]], ptr [[ALLOCA_1]], align 2
85+
86+
declare dso_local spir_func void @_Z30__spirv_GroupNonUniformShuffleIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
87+
declare dso_local spir_func void @_Z34__spirv_GroupNonUniformShuffleDownIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
88+
declare dso_local spir_func void @_Z32__spirv_GroupNonUniformShuffleUpIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr
89+
declare dso_local spir_func void @_Z33__spirv_GroupNonUniformShuffleXorIN4sycl3_V16detail9half_impl4halfEET_N5__spv5Scope4FlagES5_j(ptr addrspace(4) dead_on_unwind writable sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef, ptr noundef byval(%"class.sycl::_V1::detail::half_impl::half") align 2, i32 noundef) local_unnamed_addr

llvm-spirv/test/group_non_uniform_shuffle_down.ll

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)