
Commit a90aaa7

[SYCL][Fusion] Refine remapping of GEP instruction during internalization (#12128)
So far, we distinguished GEP instructions that select an element of the internalized buffer (must be remapped) from GEPs that address into an aggregate type (must _not_ be remapped) by looking at the number of indices. However, we can also encounter single-index GEP instructions that use a byte offset to address into padded structures, as well as multi-index GEPs with a base-pointer offset that address into an aggregate _and_ need to be remapped.

This PR uses the newly added element-size information (#12108) to correctly distinguish the required action for these kinds of GEP instructions. In addition to the E2E test, target-specific lit tests derived from it are added to demonstrate the subtle differences among SPIR-V, CUDA and HIP.

---------

Signed-off-by: Julian Oppermann <[email protected]>
1 parent 8ab4907 commit a90aaa7
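
To make the distinction concrete, here is a small hypothetical LLVM IR sketch (illustrative only, not taken from the patch; the struct layout, the 32-byte element size and the value names mirror the new lit tests) of the three GEP shapes the internalization pass now has to tell apart:

; %buf is the accessor base pointer of a buffer being internalized with
; element type %struct.MyStruct (assumed alloc size 32, as in the lit tests).
%struct.MyStruct = type { i32, <3 x i32> }

; Single index over the element type: selects element %i of the buffer,
; so this GEP needs remapping.
%elem = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %buf, i64 %i

; Also a single index, but over i8 with a constant byte offset (20 < 32): it
; only addresses a field inside the padded struct and must not be remapped.
%field = getelementptr inbounds i8, ptr addrspace(1) %elem, i64 20

; Two indices with a non-zero first index: selects element %i and then
; addresses into the aggregate, so it needs remapping and sets InAggregate.
%both = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %buf, i64 %i, i32 1

Counting indices alone would misclassify the second and third cases; comparing the GEP's source element type size against the recorded element size resolves them.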

File tree

5 files changed: +498 −24 lines changed


sycl-fusion/passes/internalization/Internalization.cpp

Lines changed: 58 additions & 24 deletions
@@ -14,6 +14,7 @@
 #include <llvm/ADT/BitVector.h>
 #include <llvm/ADT/TypeSwitch.h>
 #include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/PatternMatch.h>
 #include <llvm/Support/WithColor.h>
 #include <llvm/Transforms/Utils/Cloning.h>
 
@@ -25,6 +26,7 @@
 #define DEBUG_TYPE "sycl-fusion"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 constexpr static StringLiteral PrivatePromotion{"private"};
 constexpr static StringLiteral LocalPromotion{"local"};
@@ -191,22 +193,10 @@ static void updateInternalizationMD(Function *F, StringRef Kind,
 /// address space has changed from N to N / LocalSize.
 static void remap(GetElementPtrInst *GEPI, const PromotionInfo &PromInfo) {
   IRBuilder<> Builder{GEPI};
-  Value *C0 = Builder.getInt64(0);
-
-  auto NIdx = GEPI->getNumIndices();
-  if (NIdx > 1) {
-    // `GEPI` indexes into an aggregate. If the first index is 0, the base
-    // pointer is used as-is and we do not need to perform remapping. This is
-    // the common case.
-    // TODO: Support non-zero pointer offset, too. If the pointer operand is
-    // a GEP as well, we must check if the source element types match.
-    assert(GEPI->idx_begin()->get() == C0);
-    return;
-  }
 
   if (PromInfo.LocalSize == 1) {
     // Squash the index and let instcombine clean-up afterwards.
-    GEPI->idx_begin()->set(C0);
+    GEPI->idx_begin()->set(Builder.getInt64(0));
     return;
   }
 
@@ -290,6 +280,43 @@ Error SYCLInternalizerImpl::canPromoteCall(CallBase *C, const Value *Val,
   return Error::success();
 }
 
+enum GEPKind { INVALID = 0, NEEDS_REMAPPING, ADDRESSES_INTO_AGGREGATE };
+
+static int getGEPKind(GetElementPtrInst *GEPI, const PromotionInfo &PromInfo) {
+  assert(GEPI->getNumIndices() >= 1 && "No-op GEP encountered");
+
+  // Inspect the GEP's source element type.
+  auto &DL = GEPI->getModule()->getDataLayout();
+  auto SrcElemTySz = DL.getTypeAllocSize(GEPI->getSourceElementType());
+
+  // `GEPI`'s first index is selecting elements. Unless it is constant zero, we
+  // have to remap. If there are more indices, we start to address into an
+  // aggregate type.
+  if (SrcElemTySz == PromInfo.ElemSize) {
+    int Kind = INVALID;
+    if (!match(GEPI->idx_begin()->get(), m_ZeroInt()))
+      Kind |= NEEDS_REMAPPING;
+    if (GEPI->getNumIndices() >= 2)
+      Kind |= ADDRESSES_INTO_AGGREGATE;
+    assert(Kind != INVALID && "No-op GEP encountered");
+    return Kind;
+  }
+
+  // Check whether `GEPI` adds a constant offset, e.g. a byte offset to address
+  // into a padded structure, smaller than the element size.
+  MapVector<Value *, APInt> VariableOffsets;
+  auto IW = DL.getIndexSizeInBits(GEPI->getPointerAddressSpace());
+  APInt ConstantOffset = APInt::getZero(IW);
+  if (GEPI->collectOffset(DL, IW, VariableOffsets, ConstantOffset) &&
+      VariableOffsets.empty() &&
+      ConstantOffset.getZExtValue() < PromInfo.ElemSize) {
+    return ADDRESSES_INTO_AGGREGATE;
+  }
+
+  // We don't know what `GEPI` addresses; bail out.
+  return INVALID;
+}
+
 Error SYCLInternalizerImpl::canPromoteGEP(GetElementPtrInst *GEPI,
                                           const Value *Val,
                                           const PromotionInfo &PromInfo,
@@ -299,12 +326,17 @@ Error SYCLInternalizerImpl::canPromoteGEP(GetElementPtrInst *GEPI,
     // required.
     return Error::success();
   }
-  // Recurse to check all users of the GEP. We are either already in
-  // `InAggregate` mode, or inspect the current instruction. Recall that a GEP's
-  // first index is used to step through the base pointer, whereas any
-  // additional indices represent addressing into an aggregrate type.
+
+  // Inspect the current instruction.
+  auto Kind = getGEPKind(GEPI, PromInfo);
+  if (Kind == INVALID) {
+    return createStringError(inconvertibleErrorCode(),
+                             "Unsupported pointer arithmetic");
+  }
+
+  // Recurse to check all users of the GEP.
  return canPromoteValue(GEPI, PromInfo,
-                         InAggregate || GEPI->getNumIndices() >= 2);
+                         InAggregate || (Kind & ADDRESSES_INTO_AGGREGATE));
 }
 
 Error SYCLInternalizerImpl::canPromoteValue(Value *Val,
@@ -423,15 +455,17 @@ void SYCLInternalizerImpl::promoteGEPI(GetElementPtrInst *GEPI,
                                        bool InAggregate) const {
   // Not PointerType is unreachable. Other case is caught in caller.
   if (cast<PointerType>(GEPI->getType())->getAddressSpace() != AS) {
-    if (!InAggregate)
+    auto Kind = getGEPKind(GEPI, PromInfo);
+    assert(Kind != INVALID);
+
+    if (!InAggregate && (Kind & NEEDS_REMAPPING)) {
      remap(GEPI, PromInfo);
+    }
    GEPI->mutateType(PointerType::get(GEPI->getContext(), AS));
-    // Recurse to promote to all users of the GEP. We are either already in
-    // `InAggregate` mode, or inspect the current instruction. Recall that a
-    // GEP's first index is used to step through the base pointer, whereas any
-    // additional indices represent addressing into an aggregrate type.
+
+    // Recurse to promote to all users of the GEP.
    return promoteValue(GEPI, PromInfo,
-                        InAggregate || GEPI->getNumIndices() >= 2);
+                        InAggregate || (Kind & ADDRESSES_INTO_AGGREGATE));
  }
 }
 
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+; REQUIRES: cuda
+; RUN: opt -load-pass-plugin %shlibdir/SYCLKernelFusion%shlibext \
+; RUN:   -passes=sycl-internalization -S %s | FileCheck %s
+
+; This test is a reduced IR version of
+; sycl/test-e2e/KernelFusion/internalize_non_unit_localsize.cpp for CUDA
+
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-nvidia-cuda"
+
+%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
+%"class.sycl::_V1::detail::array" = type { [1 x i64] }
+%struct.MyStruct = type { i32, %"class.sycl::_V1::vec" }
+%"class.sycl::_V1::vec" = type { <3 x i32> }
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #0
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
+declare ptr @llvm.nvvm.implicit.offset() #1
+
+define void @fused_0(ptr addrspace(1) nocapture noundef align 16 %KernelOne__arg_accTmp,
+    ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accTmp3,
+    ptr addrspace(1) nocapture noundef readonly align 4 %KernelOne__arg_accIn,
+    ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accIn6,
+    ptr addrspace(1) nocapture noundef align 1 %KernelOne__arg_accTmp27,
+    ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accTmp210,
+    ptr addrspace(1) nocapture noundef writeonly align 4 %KernelTwo__arg_accOut,
+    ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelTwo__arg_accOut3)
+    local_unnamed_addr #3 !sycl.kernel.promote !13 !sycl.kernel.promote.localsize !14 !sycl.kernel.promote.elemsize !15 {
+; CHECK-LABEL: define void @fused_0(
+; CHECK-SAME: ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 [[KERNELONE__ARG_ACCTMP3:%[^,]*accTmp3]],
+; CHECK-SAME: ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 [[KERNELONE__ARG_ACCTMP210:%[^,]*accTmp210]]
+; CHECK: entry:
+; CHECK: [[TMP0:%.*]] = alloca i8, i64 3, align 1
+; CHECK: [[TMP1:%.*]] = alloca i8, i64 96, align 16
+; CHECK: [[KERNELONE__ARG_ACCTMP2103_SROA_0_0_COPYLOAD:%.*]] = load i64, ptr [[KERNELONE__ARG_ACCTMP210]], align 8
+; CHECK: [[KERNELONE__ARG_ACCTMP31_SROA_0_0_COPYLOAD:%.*]] = load i64, ptr [[KERNELONE__ARG_ACCTMP3]], align 8
+; CHECK: [[TMP2:%.*]] = urem i64 [[KERNELONE__ARG_ACCTMP31_SROA_0_0_COPYLOAD]], 3
+; CHECK: [[TMP3:%.*]] = urem i64 [[KERNELONE__ARG_ACCTMP2103_SROA_0_0_COPYLOAD]], 3
+; CHECK: [[MUL:%.*]] = mul nuw nsw i64 [[GLOBAL_ID:.*]], 3
+; CHECK: [[ADD:%.*]] = add nuw nsw i64 [[MUL]], 1
+; CHECK: [[TMP10:%.*]] = add i64 [[TMP2]], [[ADD]]
+; CHECK: [[TMP11:%.*]] = urem i64 [[TMP10]], 3
+; CHECK: [[ARRAYIDX_1:%.*]] = getelementptr inbounds %struct.MyStruct, ptr [[TMP1]], i64 [[TMP11]]
+
+; COM: This i8-GEP _was_ not remapped because it addresses into a single MyStruct element
+; CHECK: [[ARRAYIDX_2:%.*]] = getelementptr inbounds i8, ptr [[ARRAYIDX_1]], i64 20
+; CHECK: store i32 {{.*}}, ptr [[ARRAYIDX_2]], align 4
+; CHECK: [[TMP12:%.*]] = add i64 [[TMP3]], [[ADD]]
+; CHECK: [[TMP13:%.*]] = urem i64 [[TMP12]], 3
+
+; COM: This i8-GEP was remapped because it selects an element of the underlying i8-buffer
+; CHECK: [[ARRAYIDX_3:%.*]] = getelementptr inbounds i8, ptr [[TMP0]], i64 [[TMP13]]
+
+; CHECK: store i8 {{.*}}, ptr [[ARRAYIDX_3]], align 1
+; CHECK: store i32 {{.*}}, ptr addrspace(1)
+; CHECK: ret void
+;
+entry:
+  %KernelOne__arg_accTmp2103.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accTmp210, align 8
+  %KernelOne__arg_accIn62.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accIn6, align 8
+  %KernelOne__arg_accTmp31.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accTmp3, align 8
+  %add.ptr.j2 = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %KernelOne__arg_accTmp, i64 %KernelOne__arg_accTmp31.sroa.0.0.copyload
+  %add.ptr.i37.i = getelementptr inbounds i32, ptr addrspace(1) %KernelOne__arg_accIn, i64 %KernelOne__arg_accIn62.sroa.0.0.copyload
+  %add.ptr.i43.i = getelementptr inbounds i8, ptr addrspace(1) %KernelOne__arg_accTmp27, i64 %KernelOne__arg_accTmp2103.sroa.0.0.copyload
+  %0 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+  %conv.i1.j7 = sext i32 %0 to i64
+  %1 = tail call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %conv.i3.j7 = sext i32 %1 to i64
+  %mul.j7 = mul nsw i64 %conv.i3.j7, %conv.i1.j7
+  %2 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %conv.i2.j7 = sext i32 %2 to i64
+  %add.j7 = add nsw i64 %mul.j7, %conv.i2.j7
+  %3 = tail call ptr @llvm.nvvm.implicit.offset()
+  %4 = load i32, ptr %3, align 4
+  %conv.j8 = zext i32 %4 to i64
+  %add4.j7 = add nsw i64 %add.j7, %conv.j8
+  %mul.j2 = mul nuw nsw i64 %add4.j7, 3
+  %add.j2 = add nuw nsw i64 %mul.j2, 1
+  %arrayidx.j2 = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i37.i, i64 %add.j2
+  %5 = load i32, ptr addrspace(1) %arrayidx.j2, align 4
+  %arrayidx.i55.i = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %add.ptr.j2, i64 %add.j2
+  %arrayidx.j3 = getelementptr inbounds i8, ptr addrspace(1) %arrayidx.i55.i, i64 20
+  store i32 %5, ptr addrspace(1) %arrayidx.j3, align 4
+  %conv.j2 = trunc i32 %5 to i8
+  %arrayidx.i73.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i43.i, i64 %add.j2
+  store i8 %conv.j2, ptr addrspace(1) %arrayidx.i73.i, align 1
+  %KernelTwo__arg_accOut34.sroa.0.0.copyload = load i64, ptr %KernelTwo__arg_accOut3, align 8
+  %add.ptr.i.i7 = getelementptr inbounds i32, ptr addrspace(1) %KernelTwo__arg_accOut, i64 %KernelTwo__arg_accOut34.sroa.0.0.copyload
+  %6 = load i32, ptr %3, align 4
+  %conv.j7.i13 = zext i32 %6 to i64
+  %add4.j6.i14 = add nsw i64 %add.j7, %conv.j7.i13
+  %mul.i.i16 = mul nuw nsw i64 %add4.j6.i14, 3
+  %add.i45.i = add nuw nsw i64 %mul.i.i16, 1
+  %arrayidx.i.i17 = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %add.ptr.j2, i64 %add.i45.i
+  %arrayidx.j2.i19 = getelementptr inbounds i8, ptr addrspace(1) %arrayidx.i.i17, i64 20
+  %7 = load i32, ptr addrspace(1) %arrayidx.j2.i19, align 4
+  %arrayidx.i55.i20 = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i43.i, i64 %add.i45.i
+  %8 = load i8, ptr addrspace(1) %arrayidx.i55.i20, align 1
+  %conv.i.i22 = sext i8 %8 to i32
+  %add.i.i23 = add nsw i32 %7, %conv.i.i22
+  %arrayidx.i64.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i7, i64 %add.i45.i
+  store i32 %add.i.i23, ptr addrspace(1) %arrayidx.i64.i, align 4
+  ret void
+}
+
+attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+attributes #1 = { nofree nosync nounwind speculatable memory(none) }
+attributes #3 = { nofree nosync nounwind memory(read, argmem: readwrite, inaccessiblemem: write) "frame-pointer"="all" "target-cpu"="sm_80" "target-features"="+ptx82,+sm_80" "uniform-work-group-size"="true" }
+
+!nvvm.annotations = !{!10}
+
+!10 = !{ptr @fused_0, !"kernel", i32 1}
+!13 = !{!"private", !"none", !"none", !"none", !"private", !"none", !"none", !"none"}
+!14 = !{i64 3, !"", !"", !"", i64 3, !"", !"", !""}
+!15 = !{i64 32, !"", !"", !"", i64 1, !"", !"", !""}
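
For orientation, the urem-by-3 operations in the CHECK lines above are what remapping looks like once the accessors are internalized to 3-element private allocations (!sycl.kernel.promote.localsize is i64 3): element-selecting indices are wrapped into the smaller buffer. A minimal hypothetical before/after sketch, with %buf, %priv and %idx as placeholder names and the accessor base-offset handling omitted:

; Before internalization: %idx selects an element of the global accessor %buf.
%p = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %buf, i64 %idx

; After internalization with local size 3: the access is redirected to the
; private allocation %priv and the index is wrapped modulo 3.
%wrapped = urem i64 %idx, 3
%p.priv = getelementptr inbounds %struct.MyStruct, ptr %priv, i64 %wrapped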
