Skip to content

Commit dd8b217

Browse files
authored
Add JointMatrixGetElementCoordINTEL instruction (#1834)
The instruction returns the (Row, Column) coordinate of a dynamically selected element of a matrix. The updated version of the spec is here: intel/llvm#8175. Instruction correctness checks will be added later, among other non-backward-compatible changes. Signed-off-by: Sidorov, Dmitry [email protected]
1 parent daad382 commit dd8b217

File tree

6 files changed

+42
-7
lines changed

6 files changed

+42
-7
lines changed

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
200200
{CapabilitySubgroupAvcMotionEstimationINTEL});
201201
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
202202
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
203+
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
204+
{internal::CapabilityJointMatrixINTEL});
203205
}
204206

205207
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,9 +3330,24 @@ _SPIRV_OP(JointMatrixMad, true, 7)
33303330
_SPIRV_OP(JointMatrixSUMad, true, 7)
33313331
_SPIRV_OP(JointMatrixUSMad, true, 7)
33323332
_SPIRV_OP(JointMatrixUUMad, true, 7)
3333+
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
33333334
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
33343335
#undef _SPIRV_OP
33353336

3337+
class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
3338+
protected:
3339+
SPIRVCapVec getRequiredCapability() const override {
3340+
return getVec(internal::CapabilityJointMatrixWIInstructionsINTEL);
3341+
}
3342+
};
3343+
3344+
#define _SPIRV_OP(x, ...) \
3345+
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELWorkItemInst, \
3346+
internal::Op##x##INTEL, __VA_ARGS__> \
3347+
SPIRV##x##INTEL;
3348+
_SPIRV_OP(JointMatrixGetElementCoord, true, 5)
3349+
#undef _SPIRV_OP
3350+
33363351
class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase {
33373352
protected:
33383353
SPIRVCapVec getRequiredCapability() const override {

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
631631
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
632632
add(internal::CapabilityTensorFloat32ConversionINTEL,
633633
"TensorFloat32ConversionINTEL");
634+
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
635+
"JointMatrixWIInstructionsINTEL");
634636
}
635637
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
636638

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ _SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
1414
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
1515
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
1616
internal::OpJointMatrixWorkItemLengthINTEL)
17+
_SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL,
18+
internal::OpJointMatrixGetElementCoordINTEL)
1719
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
1820
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
1921
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ enum InternalOp {
5252
IOpConvertFToTF32INTEL = 6426,
5353
IOpMaskedGatherINTEL = 6428,
5454
IOpMaskedScatterINTEL = 6429,
55+
IOpJointMatrixGetElementCoordINTEL = 6440,
5556
IOpPrev = OpMax - 2,
5657
IOpForward
5758
};
@@ -76,7 +77,8 @@ enum InternalCapability {
7677
ICapGlobalVariableDecorationsINTEL = 6146,
7778
ICapabilityComplexFloatMulDivINTEL = 6414,
7879
ICapabilityTensorFloat32ConversionINTEL = 6425,
79-
ICapabilityMaskedGatherScatterINTEL = 6427
80+
ICapabilityMaskedGatherScatterINTEL = 6427,
81+
ICapabilityJointMatrixWIInstructionsINTEL = 6435
8082
};
8183

8284
enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
@@ -104,6 +106,7 @@ enum InternalBuiltIn {
104106

105107
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
106108
_SPIRV_OP(Capability, JointMatrixINTEL)
109+
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
107110
_SPIRV_OP(Op, TypeJointMatrixINTEL)
108111
_SPIRV_OP(Op, JointMatrixLoadINTEL)
109112
_SPIRV_OP(Op, JointMatrixStoreINTEL)
@@ -112,6 +115,8 @@ _SPIRV_OP(Op, JointMatrixSUMadINTEL)
112115
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
113116
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
114117
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
118+
_SPIRV_OP(Op, JointMatrixGetElementCoordINTEL)
119+
115120
_SPIRV_OP(Capability, HWThreadQueryINTEL)
116121
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
117122
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)

test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_element.ll

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,26 @@
55
; RUN: llvm-spirv -r -emit-opaque-pointers %t.spv -o %t.rev.bc
66
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
77

8-
; CHECK-SPIRV: Capability JointMatrixINTEL
9-
; CHECK-SPIRV: Extension "SPV_INTEL_joint_matrix"
10-
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 64
11-
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
12-
; CHECK-SPIRV: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
8+
; CHECK-SPIRV-DAG: Capability JointMatrixINTEL
9+
; CHECK-SPIRV-DAG: Capability JointMatrixWIInstructionsINTEL
10+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
11+
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt32:]] 32
12+
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt64:]] 64
13+
; CHECK-SPIRV-DAG: TypeFloat [[#TypeFloat:]] 32
14+
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
15+
; CHECK-SPIRV-DAG: TypeVector [[#TypeVec:]] [[#TypeInt32]] 2
1316
; CHECK-SPIRV: Phi [[#TypeMatrix]] [[#Matrix:]]
14-
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt]] [[#]] [[#Matrix]]
17+
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt64]] [[#]] [[#Matrix]]
1518
; CHECK-SPIRV: VectorExtractDynamic [[#TypeFloat]] [[#]] [[#Matrix]] [[#Index:]]
1619
; CHECK-SPIRV: FMul [[#TypeFloat]] [[#NewVal:]] [[#]] [[#]]
1720
; CHECK-SPIRV: VectorInsertDynamic [[#TypeMatrix]] [[#]] [[#Matrix]] [[#NewVal]] [[#Index]]
21+
; CHECK-SPIRV: JointMatrixGetElementCoordINTEL [[#TypeVec]] [[#]] [[#Matrix]] [[#Index]]
1822

1923
; CHECK-LLVM: [[Length:%.*]] = call spir_func i64 @_Z38__spirv_JointMatrixWorkItemLengthINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3(ptr addrspace(1) [[Matrix:%.*]])
2024
; CHECK-LLVM: [[Elem:%.*]] = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(ptr addrspace(1) [[Matrix]], i64 [[Index:%.*]])
2125
; CHECK-LLVM: [[NewVal:%.*]] = fmul float [[Elem]], 5.000000e+00
2226
; CHECK-LLVM: {{%.*}} = call spir_func ptr addrspace(1) @_Z27__spirv_VectorInsertDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3fl(ptr addrspace(1) [[Matrix]], float [[NewVal]], i64 [[Index]])
27+
; CHECK-LLVM: {{%.*}} = call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(ptr addrspace(1) [[Matrix]], i64 [[Index]])
2328

2429
source_filename = "/work/tmp/matrix-slice.cpp"
2530
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
@@ -69,6 +74,7 @@ for.body.i: ; preds = %for.cond.i
6974
%call.i.i = tail call spir_func float @_Z28__spirv_VectorExtractDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmET_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEET4_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
7075
%mul.i.i = fmul float %call.i.i, 5.000000e+00
7176
%call5.i.i = tail call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* @_Z27__spirv_VectorInsertDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEES7_T4_S5_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, float %mul.i.i, i64 %conv.i) #2
77+
%call6 = tail call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
7278
%inc.i = add nuw nsw i32 %i.0.i, 1
7379
br label %for.cond.i, !llvm.loop !7
7480

@@ -92,6 +98,9 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4
9298
; Function Attrs: convergent
9399
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(float addrspace(4)*, %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1
94100

101+
; Function Attrs: convergent
102+
declare dso_local spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64) #2
103+
95104
attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="/work/tmp/matrix-slice.cpp" "uniform-work-group-size"="true" }
96105
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
97106
attributes #2 = { convergent }

0 commit comments

Comments
 (0)