Skip to content

Commit 86837df

Browse files
Ensure that the correct types are applied to virtual registers that were used as arguments in call lowering and therefore caused early definition of SPIR-V types
1 parent 02bf2fd commit 86837df

File tree

4 files changed

+121
-7
lines changed

4 files changed

+121
-7
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,16 +545,29 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
545545
Register ArgReg = Arg.Regs[0];
546546
ArgVRegs.push_back(ArgReg);
547547
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
548-
// If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) we should
549-
// wait with setting the type for the virtual register until pre-legalizer
550-
// step when we access @llvm.spv.assign.ptr.type.p...(...)'s info.
551-
if (!SpvType && !isUntypedPointerTy(Arg.Ty)) {
552-
SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
553-
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
548+
if (!SpvType) {
549+
Type *ArgTy = nullptr;
550+
if (auto *PtrArgTy = dyn_cast<PointerType>(Arg.Ty)) {
551+
// If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
552+
// don't have access to the original value in LLVM IR or to info about
553+
// the deduced pointee type, then we should defer setting the type for
554+
// the virtual register until the pre-legalizer step, when we access
555+
// @llvm.spv.assign.ptr.type.p...(...)'s info.
556+
if (Arg.OrigValue)
557+
if (Type *ElemTy = GR->findDeducedElementType(Arg.OrigValue))
558+
ArgTy = TypedPointerType::get(ElemTy, PtrArgTy->getAddressSpace());
559+
} else {
560+
ArgTy = Arg.Ty;
561+
}
562+
if (ArgTy) {
563+
SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder);
564+
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
565+
}
554566
}
555567
if (!MRI->getRegClassOrNull(ArgReg)) {
556568
// Either we have SpvType created, or Arg.Ty is an untyped pointer and
557-
// we know its virtual register's class and type.
569+
// we know its virtual register's class and type even if we don't know
570+
// the pointee type.
558571
MRI->setRegClass(ArgReg, SpvType ? GR->getRegClass(SpvType)
559572
: &SPIRV::pIDRegClass);
560573
MRI->setType(

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
12191219
SmallVector<Value *, 2> Args = {Pointer, VMD, B.getInt32(AddressSpace)};
12201220
auto *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
12211221
I->setOperand(OperandToReplace, PtrCastI);
1222+
// We need to set up a pointee type for the newly created spv_ptrcast.
1223+
buildAssignPtr(B, ExpectedElementType, PtrCastI);
12221224
}
12231225

12241226
void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; The goal of the test case is to ensure that correct types are applied to virtual registers
2+
; which were used as arguments in call lowering and so caused early definition of SPIR-V types.
3+
4+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
5+
6+
%t_id = type { %t_arr }
7+
%t_arr = type { [1 x i64] }
8+
%t_bf16 = type { i16 }
9+
10+
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) align 4 %_arg_ERR, ptr byval(%t_id) align 8 %_arg_ERR3) {
11+
entry:
12+
%FloatArray.i = alloca [4 x float], align 4
13+
%BF16Array.i = alloca [4 x %t_bf16], align 2
14+
%0 = load i64, ptr %_arg_ERR3, align 8
15+
%add.ptr.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_ERR, i64 %0
16+
%FloatArray.ascast.i = addrspacecast ptr %FloatArray.i to ptr addrspace(4)
17+
%BF16Array.ascast.i = addrspacecast ptr %BF16Array.i to ptr addrspace(4)
18+
call spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4) %FloatArray.ascast.i, ptr addrspace(4) %BF16Array.ascast.i)
19+
br label %for.cond.i
20+
21+
for.cond.i: ; preds = %for.inc.i, %entry
22+
%lsr.iv1 = phi ptr [ %scevgep2, %for.inc.i ], [ %FloatArray.i, %entry ]
23+
%lsr.iv = phi ptr addrspace(4) [ %scevgep, %for.inc.i ], [ %BF16Array.ascast.i, %entry ]
24+
%i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.inc.i ]
25+
%cmp.i = icmp ult i32 %i.0.i, 4
26+
br i1 %cmp.i, label %for.body.i, label %exit
27+
28+
for.body.i: ; preds = %for.cond.i
29+
%1 = load float, ptr %lsr.iv1, align 4
30+
%call.i.i = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2) %lsr.iv)
31+
%cmp5.i = fcmp une float %1, %call.i.i
32+
br i1 %cmp5.i, label %if.then.i, label %for.inc.i
33+
34+
if.then.i: ; preds = %for.body.i
35+
store i32 1, ptr addrspace(1) %add.ptr.i, align 4
36+
br label %for.inc.i
37+
38+
for.inc.i: ; preds = %if.then.i, %for.body.i
39+
%inc.i = add nuw nsw i32 %i.0.i, 1
40+
%scevgep = getelementptr i8, ptr addrspace(4) %lsr.iv, i64 2
41+
%scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
42+
br label %for.cond.i
43+
44+
exit: ; preds = %for.cond.i
45+
ret void
46+
}
47+
48+
declare void @llvm.memcpy.p0.p1.i64(ptr noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg)
49+
declare dso_local spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4), ptr addrspace(4))
50+
declare dso_local spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
; The goal of the test case is to ensure that correct types are applied to virtual registers
2+
; which were used as arguments in call lowering and so caused early definition of SPIR-V types.
3+
4+
; RUN: %if spirv-tools %{ llc -O2 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
5+
6+
%t_id = type { %t_arr }
7+
%t_arr = type { [1 x i64] }
8+
%t_bf16 = type { i16 }
9+
10+
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) align 4 %_arg_ERR, ptr byval(%t_id) align 8 %_arg_ERR3) {
11+
entry:
12+
%FloatArray.i = alloca [4 x float], align 4
13+
%BF16Array.i = alloca [4 x %t_bf16], align 2
14+
%0 = load i64, ptr %_arg_ERR3, align 8
15+
%add.ptr.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_ERR, i64 %0
16+
%FloatArray.ascast.i = addrspacecast ptr %FloatArray.i to ptr addrspace(4)
17+
%BF16Array.ascast.i = addrspacecast ptr %BF16Array.i to ptr addrspace(4)
18+
call spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4) %FloatArray.ascast.i, ptr addrspace(4) %BF16Array.ascast.i)
19+
br label %for.cond.i
20+
21+
for.cond.i: ; preds = %for.inc.i, %entry
22+
%i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.inc.i ]
23+
%cmp.i = icmp ult i32 %i.0.i, 4
24+
br i1 %cmp.i, label %for.body.i, label %exit
25+
26+
for.body.i: ; preds = %for.cond.i
27+
%idxprom.i = zext nneg i32 %i.0.i to i64
28+
%arrayidx.i = getelementptr inbounds [4 x float], ptr %FloatArray.i, i64 0, i64 %idxprom.i
29+
%1 = load float, ptr %arrayidx.i, align 4
30+
%arrayidx4.i = getelementptr inbounds [4 x %t_bf16], ptr addrspace(4) %BF16Array.ascast.i, i64 0, i64 %idxprom.i
31+
%call.i.i = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2) %arrayidx4.i)
32+
%cmp5.i = fcmp une float %1, %call.i.i
33+
br i1 %cmp5.i, label %if.then.i, label %for.inc.i
34+
35+
if.then.i: ; preds = %for.body.i
36+
store i32 1, ptr addrspace(1) %add.ptr.i, align 4
37+
br label %for.inc.i
38+
39+
for.inc.i: ; preds = %if.then.i, %for.body.i
40+
%inc.i = add nuw nsw i32 %i.0.i, 1
41+
br label %for.cond.i
42+
43+
exit: ; preds = %for.cond.i
44+
ret void
45+
}
46+
47+
declare void @llvm.memcpy.p0.p1.i64(ptr noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg)
48+
declare dso_local spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4), ptr addrspace(4))
49+
declare dso_local spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2))

0 commit comments

Comments
 (0)