Skip to content

Commit f7d0459

Browse files
svenvhjsji
authored andcommitted
Add SPV_EXT_image_raw10_raw12 writer support (#2132)
Add basic LLVM-to-SPIR-V support for the SPV_EXT_image_raw10_raw12 extension, enabling the extension if constants from the extension are found in the LLVM IR. The extension adds 2 new return values for `OpImageQueryFormat`, which are integer constants that may appear anywhere in LLVM IR. Distinguishing between the extension's constants and arbitrary constants that happen to have the same value as the extension's constants is not possible in general. Hence this patch only covers some common use cases, where the result of `OpImageQueryFormat` is directly used in an integer comparison or switch instruction. Original commit: KhronosGroup/SPIRV-LLVM-Translator@8564fd4
1 parent 582f653 commit f7d0459

File tree

4 files changed

+114
-0
lines changed

4 files changed

+114
-0
lines changed

llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@
4848
#include "llvm/IR/IRBuilder.h"
4949
#include "llvm/IR/Instruction.h"
5050
#include "llvm/IR/Instructions.h"
51+
#include "llvm/IR/PatternMatch.h"
5152
#include "llvm/Support/Debug.h"
5253

5354
#include <algorithm>
5455
#include <set>
5556

5657
using namespace llvm;
58+
using namespace PatternMatch;
5759
using namespace SPIRV;
5860
using namespace OCLUtil;
5961

@@ -1355,10 +1357,56 @@ void OCLToSPIRVBase::visitCallScalToVec(CallInst *CI, StringRef MangledName,
13551357
});
13561358
}
13571359

1360+
namespace {
1361+
// Return true if any users of the CallInst use any of the constants
1362+
// introduced by the SPV_EXT_image_raw10_raw12 extension.
1363+
bool usesSpvExtImageRaw10Raw12Constants(const CallInst *CI) {
1364+
const std::array ExtConstants{
1365+
OCLImageChannelDataTypeOffset + ImageChannelDataTypeUnsignedIntRaw10EXT,
1366+
OCLImageChannelDataTypeOffset + ImageChannelDataTypeUnsignedIntRaw12EXT};
1367+
1368+
// The return values for `OpImageQueryFormat` added by the extension are
1369+
// integer constants that may appear anywhere in LLVM IR. Only detect some
1370+
// common use patterns here.
1371+
for (auto *U : CI->users()) {
1372+
for (auto C : ExtConstants) {
1373+
ICmpInst::Predicate Pred;
1374+
if (match(U, m_c_ICmp(Pred, m_Value(), m_SpecificInt(C)))) {
1375+
return true;
1376+
}
1377+
if (auto *Switch = dyn_cast<SwitchInst>(U)) {
1378+
if (any_of(Switch->cases(), [C](const auto &Case) {
1379+
return Case.getCaseValue()->equalsInt(C);
1380+
})) {
1381+
return true;
1382+
}
1383+
}
1384+
}
1385+
}
1386+
return false;
1387+
}
1388+
} // anonymous namespace
1389+
13581390
void OCLToSPIRVBase::visitCallGetImageChannel(CallInst *CI,
13591391
StringRef DemangledName,
13601392
unsigned int Offset) {
13611393
assert(CI->getCalledFunction() && "Unexpected indirect call");
1394+
1395+
if (Offset == OCLImageChannelDataTypeOffset) {
1396+
// See if any of the SPV_EXT_image_raw10_raw12 constants are used, and
1397+
// add the extension if not already there.
1398+
if (usesSpvExtImageRaw10Raw12Constants(CI)) {
1399+
const char *ExtStr = "SPV_EXT_image_raw10_raw12";
1400+
NamedMDNode *NMD = M->getOrInsertNamedMetadata(kSPIRVMD::Extension);
1401+
if (none_of(NMD->operands(), [ExtStr](MDNode *N) {
1402+
return N->getOperand(0).equalsStr(ExtStr);
1403+
})) {
1404+
MDString *MDS = MDString::get(*Ctx, ExtStr);
1405+
NMD->addOperand(MDNode::get(*Ctx, MDS));
1406+
}
1407+
}
1408+
}
1409+
13621410
Op OC = OpNop;
13631411
OCLSPIRVBuiltinMap::find(DemangledName.str(), &OC);
13641412
mutateCallInst(CI, OC).changeReturnType(
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %clang_cc1 -triple spir-unknown-unknown -O1 -cl-std=CL2.0 -emit-llvm-bc %s -o %t.bc
2+
// RUN: llvm-spirv --spirv-ext=+SPV_EXT_image_raw10_raw12 %t.bc -o %t.spv
3+
// RUN: llvm-spirv --spirv-ext=+SPV_EXT_image_raw10_raw12 %t.spv -to-text -o %t.spt
4+
// RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
// CHECK-SPIRV-NOT: Extension "SPV_EXT_image_raw10_raw12"
7+
8+
// Test that use of constant values equal to the extension's constants do not enable the extension.
9+
10+
kernel void test_raw1012(global int *dst, int value) {
11+
switch (value) {
12+
case 0x10E3: // same value as CLK_UNSIGNED_INT_RAW10_EXT
13+
*dst = 10;
14+
break;
15+
case 0x10E4: // same value as CLK_UNSIGNED_INT_RAW12_EXT
16+
*dst = 12;
17+
break;
18+
}
19+
20+
if (value==0x10E3 || value==0x10E4) {
21+
*dst = 1012;
22+
}
23+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: %clang_cc1 -triple spir-unknown-unknown -O1 -cl-std=CL2.0 -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc %s -o %t.bc
2+
// RUN: llvm-spirv --spirv-ext=+SPV_EXT_image_raw10_raw12 %t.bc -o %t.spv
3+
// RUN: llvm-spirv --spirv-ext=+SPV_EXT_image_raw10_raw12 %t.spv -to-text -o %t.spt
4+
// RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
// RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefixes=CHECK-COMMON,CHECK-LLVM
8+
// RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.rev.bc
9+
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefixes=CHECK-COMMON,CHECK-SPV-IR
10+
11+
// RUN: not llvm-spirv --spirv-ext=-SPV_EXT_image_raw10_raw12 %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-EXT-OFF
12+
// CHECK-EXT-OFF: Feature requires the following SPIR-V extension
13+
// CHECK-EXT-OFF-NEXT: SPV_EXT_image_raw10_raw12
14+
15+
// CHECK-SPIRV: Extension "SPV_EXT_image_raw10_raw12"
16+
17+
// CHECK-COMMON: test_raw1012
18+
// CHECK-LLVM: _Z27get_image_channel_data_type14ocl_image2d_ro
19+
// CHECK-SPV-IR: call spir_func i32 @_Z24__spirv_ImageQueryFormatPU3AS133__spirv_Image__void_1_0_0_0_0_0_0
20+
// CHECK-COMMON: switch i32
21+
// CHECK-COMMON: i32 4323,
22+
// CHECK-COMMON: i32 4324,
23+
// CHECK-COMMON: icmp eq i32 %{{.*}}, 4323
24+
// CHECK-COMMON: icmp eq i32 %{{.*}}, 4324
25+
26+
kernel void test_raw1012(global int *dst, read_only image2d_t img) {
27+
switch (get_image_channel_data_type(img)) {
28+
case CLK_SNORM_INT8:
29+
*dst = 8;
30+
break;
31+
case CLK_UNSIGNED_INT_RAW10_EXT:
32+
*dst = 10;
33+
break;
34+
case CLK_UNSIGNED_INT_RAW12_EXT:
35+
*dst = 12;
36+
break;
37+
}
38+
39+
if (get_image_channel_data_type(img) == CLK_UNSIGNED_INT_RAW10_EXT)
40+
*dst = 1010;
41+
else if (CLK_UNSIGNED_INT_RAW12_EXT == get_image_channel_data_type(img))
42+
*dst = 1212;
43+
}

0 commit comments

Comments
 (0)