Skip to content

[SPIR-V] Fix vloadn OpenCL builtin lowering #81148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ struct VectorLoadStoreBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
uint32_t Number;
uint32_t ElementCount;
bool IsRounded;
FPRoundingMode::FPRoundingMode RoundingMode;
};
Expand Down Expand Up @@ -1851,6 +1852,7 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
.addImm(Builtin->Number);
for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
MIB.addImm(Builtin->ElementCount);

// Rounding mode should be passed as a last argument in the MI for builtins
// like "vstorea_halfn_r".
Expand Down
16 changes: 11 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1046,18 +1046,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
string Name = name;
InstructionSet Set = set;
bits<32> Number = number;
bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
!not(!eq(!find(name, "3"), -1)) : 3,
!not(!eq(!find(name, "4"), -1)) : 4,
!not(!eq(!find(name, "8"), -1)) : 8,
!not(!eq(!find(name, "16"), -1)) : 16,
true : 1);
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
true : RTE);
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
true : RTE);
}

// Table gathering all the vector data load/store builtins.
def VectorLoadStoreBuiltins : GenericTable {
let FilterClass = "VectorLoadStoreBuiltin";
let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
Expand Down
27 changes: 13 additions & 14 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
bool AllowCastingToChar = false;

StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
if (!ArgType || ArgType->getString().starts_with("uchar*"))
return;

// Handle special case when StoreInst's value operand is a kernel argument
// of a pointer type. Since these arguments could have either a basic
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
// the StoreInst's value operand to default pointer element type (i8).
Pointer = Arg;
Pointer = SI->getValueOperand();
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
OperandToReplace = 0;
AllowCastingToChar = true;
} else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
Expand Down Expand Up @@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
}

// Do not emit spv_ptrcast if it would cast to the default pointer element
// type (i8) of the same address space.
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
// type (i8) of the same address space. In case of OpenCL kernels, make sure
// i8 is the pointer element type defined for the given kernel argument.
if (ExpectedElementType->isIntegerTy(8) &&
F->getCallingConv() != CallingConv::SPIR_KERNEL)
return;

Argument *Arg = dyn_cast<Argument>(Pointer);
if (ExpectedElementType->isIntegerTy(8) &&
F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
if (ArgType && ArgType->getString().starts_with("uchar*"))
return;
}

// If this would be the first spv_ptrcast, the pointer's defining instruction
// requires spv_assign_ptr_type and does not already have one, do not emit
// spv_ptrcast and emit spv_assign_ptr_type instead.
Expand Down
95 changes: 0 additions & 95 deletions llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll

This file was deleted.

40 changes: 40 additions & 0 deletions llvm/test/CodeGen/SPIRV/opencl/vload2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why not running spirv-val on this too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment below. It looks like the SPIR-V validator is doing additional checks on the arguments of those OpExtInst calls and those are not valid here. I want this test to only check the vloadn builtin name resolution to keep the tests as small as possible.

; This test only intends to check the vloadn builtin name resolution.
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.

; This test only intends to check the vloadn builtin name resolution.
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.

; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"

; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]

; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]

define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
ret void
}

declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer Workgroup %[[#INT8]]
; CHECK-DAG: %[[#PTRVINT8:]] = OpTypePointer Workgroup %[[#VINT8]]
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT64]] 1

; CHECK: %[[#PARAM1:]] = OpFunctionParameter %[[#PTRVINT8]]
define spir_kernel void @test1(ptr addrspace(3) %address) !kernel_arg_type !1 {
; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM1]]
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST1]] %[[#CONST]]
%cast = bitcast ptr addrspace(3) %address to ptr addrspace(3)
%gep = getelementptr inbounds i8, ptr addrspace(3) %cast, i64 1
ret void
}

; CHECK: %[[#PARAM2:]] = OpFunctionParameter %[[#PTRVINT8]]
define spir_kernel void @test2(ptr addrspace(3) %address) !kernel_arg_type !1 {
; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM2]]
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST2]] %[[#CONST]]
%gep = getelementptr inbounds i8, ptr addrspace(3) %address, i64 1
ret void
}

declare spir_func <2 x i8> @_Z6vload2mPU3AS3Kc(i64, ptr addrspace(3))

!1 = !{!"char2*"}