Skip to content

Commit a3e4b9e

Browse files
authored
[SYCL][ESIMD] Fix invalid IR produced by ESIMDOptimizeVecArgCallConv (#9438)
This pass, among other things, replaces a sret pointer argument with an alloca. If the alloca addrspace and argument addrspace do not match, we need to cast. This fixes two cases of invalid IR produced by running tests with `-O0`. --------- Signed-off-by: Sarnie, Nick <[email protected]>
1 parent fda790a commit a3e4b9e

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

llvm/lib/SYCLLowerIR/ESIMD/ESIMDOptimizeVecArgCallConv.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,16 @@ optimizeFunction(Function *OldF,
333333
Align Al = DL.getPrefTypeAlign(T);
334334
unsigned AddrSpace = DL.getAllocaAddrSpace();
335335
AllocaInst *Alloca = new AllocaInst(T, AddrSpace, 0 /*array size*/, Al);
336-
VMap[OldF->getArg(PI.getFormalParam().getArgNo())] = Alloca;
337336
NewInsts.push_back(Alloca);
337+
Instruction *ReplaceInst = Alloca;
338+
if (auto *ArgPtrType = dyn_cast<PointerType>(PI.getFormalParam().getType());
339+
ArgPtrType && ArgPtrType->getAddressSpace() != AddrSpace) {
340+
// If the alloca addrspace and arg addrspace are different,
341+
// insert a cast.
342+
ReplaceInst = new AddrSpaceCastInst(Alloca, ArgPtrType);
343+
NewInsts.push_back(ReplaceInst);
344+
}
345+
VMap[OldF->getArg(PI.getFormalParam().getArgNo())] = ReplaceInst;
338346

339347
if (!PI.isSret()) {
340348
// Create a store of the new optimized parameter into the alloca to
@@ -365,7 +373,11 @@ optimizeFunction(Function *OldF,
365373
IRBuilder<> Bld(RI);
366374
const FormalParamInfo &PI = OptimizeableParams[SretInd];
367375
Argument *OldP = OldF->getArg(PI.getFormalParam().getArgNo());
368-
auto *SretPtr = cast<AllocaInst>(VMap[OldP]);
376+
auto *SretPtr = cast<Instruction>(VMap[OldP]);
377+
if (!isa<AllocaInst>(SretPtr)) {
378+
auto *AddrSpaceCast = cast<AddrSpaceCastInst>(SretPtr);
379+
SretPtr = cast<AllocaInst>(AddrSpaceCast->getPointerOperand());
380+
}
369381
LoadInst *Ld = Bld.CreateLoad(PI.getOptimizedType(), SretPtr);
370382
Bld.CreateRet(Ld);
371383
}

llvm/test/SYCLLowerIR/ESIMD/vec_arg_call_conv.ll

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,15 @@ define dso_local spir_func void @_Z19callee__sret__param(ptr addrspace(4) noalia
5555
; CHECK: define dso_local spir_func <16 x float> @_Z19callee__sret__param(<16 x float> %[[PARAM:.+]])
5656
entry:
5757
; CHECK: %[[ALLOCA1:.+]] = alloca <16 x float>, align 64
58+
; CHECK: %[[CAST1:.+]] = addrspacecast ptr %[[ALLOCA1]] to ptr addrspace(4)
5859
; CHECK: %[[ALLOCA2:.+]] = alloca <16 x float>, align 64
5960
; CHECK: store <16 x float> %[[PARAM]], ptr %[[ALLOCA2]], align 64
6061
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
6162
; CHECK: %[[ALLOCA2_4:.+]] = addrspacecast ptr %[[ALLOCA2]] to ptr addrspace(4)
6263
%call.i.i.i1 = load <16 x float>, ptr addrspace(4) %x.ascast, align 64
6364
; CHECK: %[[VAL:.+]] = load <16 x float>, ptr addrspace(4) %[[ALLOCA2_4]], align 64
6465
store <16 x float> %call.i.i.i1, ptr addrspace(4) %agg.result, align 64
65-
; CHECK: store <16 x float> %[[VAL]], ptr %[[ALLOCA1]], align 64
66+
; CHECK: store <16 x float> %[[VAL]], ptr addrspace(4) %[[CAST1]], align 64
6667
ret void
6768
; CHECK: %[[RET:.+]] = load <16 x float>, ptr %[[ALLOCA1]], align 64
6869
; CHECK: ret <16 x float> %[[RET]]
@@ -74,6 +75,7 @@ define dso_local spir_func void @_Z29test__sret__fall_through__arr(ptr addrspace
7475
; CHECK: define dso_local spir_func <16 x float> @_Z29test__sret__fall_through__arr(ptr addrspace(4) noundef %[[PARAM0:.+]], i32 noundef %{{.*}})
7576
entry:
7677
; CHECK: %[[ALLOCA1:.+]] = alloca <16 x float>, align 64
78+
; CHECK: %[[CAST1:.+]] = addrspacecast ptr %[[ALLOCA1]] to ptr addrspace(4)
7779
%agg.tmp = alloca %"class.sycl::_V1::ext::intel::esimd::simd", align 64
7880
; CHECK: %[[ALLOCA2:.+]] = alloca %"class.sycl::_V1::ext::intel::esimd::simd", align 64
7981
%agg.tmp.ascast = addrspacecast ptr %agg.tmp to ptr addrspace(4)
@@ -85,7 +87,7 @@ entry:
8587
; CHECK: %[[VAL:.+]] = load <16 x float>, ptr %[[ALLOCA2]], align 64
8688
call spir_func void @_Z19callee__sret__param(ptr addrspace(4) sret(%"class.sycl::_V1::ext::intel::esimd::simd") align 64 %agg.result, ptr noundef nonnull %agg.tmp) #7
8789
; CHECK: %[[RES:.+]] = call spir_func <16 x float> @_Z19callee__sret__param(<16 x float> %[[VAL]])
88-
; CHECK: store <16 x float> %[[RES]], ptr %[[ALLOCA1]], align 64
90+
; CHECK: store <16 x float> %[[RES]], ptr addrspace(4) %[[CAST1]], align 64
8991
ret void
9092
; CHECK: %[[RET:.+]] = load <16 x float>, ptr %[[ALLOCA1]], align 64
9193
; CHECK: ret <16 x float> %[[RET]]
@@ -96,6 +98,7 @@ entry:
9698
define dso_local spir_func void @_Z30test__sret__fall_through__globv(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd") align 64 %agg.result) local_unnamed_addr #2 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
9799
entry:
98100
; CHECK: %[[ALLOCA1:.+]] = alloca <16 x float>, align 64
101+
; CHECK: %[[CAST1:.+]] = addrspacecast ptr %[[ALLOCA1]] to ptr addrspace(4)
99102
%agg.tmp = alloca %"class.sycl::_V1::ext::intel::esimd::simd", align 64
100103
; CHECK: %[[ALLOCA2:.+]] = alloca %"class.sycl::_V1::ext::intel::esimd::simd", align 64
101104
%agg.tmp.ascast = addrspacecast ptr %agg.tmp to ptr addrspace(4)
@@ -105,7 +108,7 @@ entry:
105108
; CHECK: %[[VAL:.+]] = load <16 x float>, ptr %[[ALLOCA2]], align 64
106109
call spir_func void @_Z19callee__sret__param(ptr addrspace(4) sret(%"class.sycl::_V1::ext::intel::esimd::simd") align 64 %agg.result, ptr noundef nonnull %agg.tmp) #7
107110
; CHECK: %[[RES:.+]] = call spir_func <16 x float> @_Z19callee__sret__param(<16 x float> %[[VAL]])
108-
; CHECK: store <16 x float> %[[RES]], ptr %[[ALLOCA1]], align 64
111+
; CHECK: store <16 x float> %[[RES]], ptr addrspace(4) %[[CAST1]], align 64
109112
ret void
110113
; CHECK: %[[RET:.+]] = load <16 x float>, ptr %[[ALLOCA1]], align 64
111114
; CHECK: ret <16 x float> %[[RET]]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: opt -passes=esimd-opt-call-conv -S < %s | FileCheck %s
2+
; This test checks the ESIMDOptimizeVecArgCallConvPass optimization with a
3+
; use of the sret argument relying on the address space.
4+
5+
; ModuleID = 'opaque_ptr.bc'
6+
source_filename = "llvm-link"
7+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
8+
target triple = "spir64-unknown-unknown"
9+
10+
%"class.sycl::_V1::ext::intel::esimd::simd.0" = type { %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.1" }
11+
%"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.1" = type { <16 x float> }
12+
13+
define linkonce_odr dso_local spir_func void @foo(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.0") align 128 %agg.result,
14+
ptr noundef byval(%"class.sycl::_V1::ext::intel::esimd::simd.0") align 128 %val) {
15+
; CHECK: [[ALLOCA:%.*]] = alloca <16 x float>, align 64
16+
; CHECK: [[CAST:%.*]] = addrspacecast ptr [[ALLOCA]] to ptr addrspace(4)
17+
; CHECK: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 128 [[CAST]], ptr addrspace(4) align 128 [[ARGCAST:%.*]], i64 128, i1 false)
18+
; CHECK: [[LOAD:%.*]] = load <16 x float>, ptr [[ALLOCA]], align 64
19+
; CHECK: ret <16 x float> [[LOAD]]
20+
21+
entry:
22+
%val.ascast = addrspacecast ptr %val to ptr addrspace(4)
23+
call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 128 %agg.result, ptr addrspace(4) align 128 %val.ascast, i64 128, i1 false)
24+
ret void
25+
}
26+
27+
; Function Attrs: alwaysinline nocallback nofree nounwind willreturn memory(argmem: readwrite)
28+
declare void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) noalias nocapture writeonly %0, ptr addrspace(4) noalias nocapture readonly %1, i64 %2, i1 immarg %3)

0 commit comments

Comments
 (0)