Skip to content

Commit cb5ea20

Browse files
YuriPlyakhinigcbot
authored andcommitted
Optimize SYCL joint_matrix_apply lowering for accumulator 32x64
1. Optimize resolving of slice extract and insert for accumulator 32x64 to use GEP/Load/Store for accessing/updating matrix elements instead of extracting vectors from arrays and composing new arrays. 2. Make sure loop used inside joint_matrix_apply implementation is always fully unrolled.
1 parent 6e66432 commit cb5ea20

File tree

4 files changed

+85
-55
lines changed

4 files changed

+85
-55
lines changed

IGC/Compiler/GenTTI.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,21 @@ namespace llvm {
159159
#endif
160160
)
161161
{
162+
// Always unroll joint_matrix_apply loop
163+
for (auto BB : L->blocks())
164+
{
165+
for (auto &I : *BB)
166+
{
167+
if (auto *MD = I.getMetadata("joint_matrix_apply"))
168+
{
169+
UP.Threshold = UINT_MAX;
170+
UP.UpperBound = true;
171+
UP.Force = true;
172+
return;
173+
}
174+
}
175+
}
176+
162177
unsigned LoopUnrollThreshold = ctx->m_DriverInfo.GetLoopUnrollThreshold();
163178

164179
// override the LoopUnrollThreshold if the registry key is set

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
299299
ResolvedValues.clear();
300300
ResolvedTypes.clear();
301301
InstsToErase.clear();
302+
MatrixAllocas.clear();
302303
m_SIMDSize = 0;
303304

304305
// Use reverse post order traversal to reduce level or recursion
@@ -1876,6 +1877,30 @@ static Value *mergeComponentToPackedValue(BuilderT *builder, Value *value, Value
18761877
return builder->CreateOr(value, component);
18771878
}
18781879

1880+
// Gets pointer to element to process in joint_matrix_apply loop for Accumulator 32x64
1881+
// Also updates MatPtr to point to alloca of [2 x <float x 64>] used inside joint_matrix_apply loop
1882+
Value *JointMatrixFuncsResolutionPass::getAcc32x64ElementPtr(CallInst *CI, Value *matrix, Value *index, IRBuilder<> *builder, Value **MatPtr) {
1883+
if (LoadInst *loadInst = dyn_cast<LoadInst>(matrix)) {
1884+
*MatPtr = Resolve(loadInst->getPointerOperand());
1885+
} else {
1886+
// Use existing alloca or create alloca in the entry node of the function
1887+
*MatPtr = MatrixAllocas[matrix];
1888+
if (!*MatPtr) {
1889+
builder->SetInsertPoint(&*CI->getFunction()->getEntryBlock().getFirstInsertionPt());
1890+
builder->SetCurrentDebugLocation(CI->getDebugLoc());
1891+
*MatPtr = builder->CreateAlloca(matrix->getType(), ADDRESS_SPACE_PRIVATE);
1892+
MatrixAllocas[matrix] = *MatPtr;
1893+
builder->SetInsertPoint(CI);
1894+
}
1895+
builder->CreateStore(matrix, *MatPtr);
1896+
}
1897+
1898+
Value *FloatPtr = builder->CreateBitCast(*MatPtr, builder->getFloatTy()->getPointerTo((*MatPtr)->getType()->getPointerAddressSpace()));
1899+
1900+
// create GEP to extract element by 'index' from 'matrix'
1901+
return builder->CreateGEP(builder->getFloatTy(), FloatPtr, index);
1902+
}
1903+
18791904
Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
18801905
Value *matrix = Resolve(CI->getArgOperand(0));
18811906
Value *component = CI->getArgOperand(1);
@@ -1906,18 +1931,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
19061931
// Special case Accumulator 32x64 is represented as [2 x <float x 64>].
19071932
if (isAccumulator32x64(desc))
19081933
{
1909-
// extract first or second half of array
1910-
Value *indexArray = builder.CreateICmpUGT(index, ConstantInt::get(index->getType(), 63)); // i1 0 or 1
1911-
Value *half0 = builder.CreateExtractValue(matrix, {0}, "matrix.slice.half0");
1912-
Value *half1 = builder.CreateExtractValue(matrix, {1}, "matrix.slice.half1");
1913-
Value *halfMatrix = builder.CreateSelect(indexArray, half1, half0, "matrix.slice.selected.half"); // <64 x float>
1914-
1915-
// insert new component to vector <64 x float> and then insert new vector to array of 2 vectors
1916-
Value* indexVec = builder.CreateURem(index, ConstantInt::get(index->getType(), 64)); // 0..63
1917-
slice = builder.CreateInsertElement(halfMatrix, component, indexVec);
1918-
Value *newHalf0 = builder.CreateSelect(indexArray, half0, slice);
1919-
Value *newHalf1 = builder.CreateSelect(indexArray, slice, half1);
1920-
slice = createPair(&builder, getAcc32x64HalfType(builder.getContext()), newHalf0, newHalf1);
1934+
Value *MatPtr = nullptr;
1935+
Value *ptrToElem = getAcc32x64ElementPtr(CI, matrix, index, &builder, &MatPtr);
1936+
builder.CreateStore(component, ptrToElem);
1937+
slice = builder.CreateLoad(matTy, MatPtr);
19211938
}
19221939
else if (dyn_cast<IGCLLVM::FixedVectorType>(matTy))
19231940
slice = builder.CreateInsertElement(matrix, component, index);
@@ -1942,17 +1959,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
19421959
Value *indexVec = index;
19431960
element = updateIndexAndCreateSliceExtract(&builder, matrix, &indexVec, desc.contribBitWidth, desc.bitWidth);
19441961
} else if (isAccumulator32x64(desc)) {
1945-
// Get index of which element of array to use: 0 or 1
1946-
Value* indexArray = builder.CreateICmpUGT(index, ConstantInt::get(index->getType(), 63));
1947-
1948-
// Select half that we need:
1949-
Value* half0 = builder.CreateExtractValue(matrix, {0}, "matrix.slice.half0");
1950-
Value* half1 = builder.CreateExtractValue(matrix, {1}, "matrix.slice.half1");
1951-
Value* halfMatrix = builder.CreateSelect(indexArray, half1, half0, "matrix.slice.selected.half");
1952-
1953-
// get index of element inside vector of 64 elements
1954-
Value* indexVec = builder.CreateURem(index, ConstantInt::get(index->getType(), 64)); // 0..63
1955-
element = updateIndexAndCreateSliceExtract(&builder, halfMatrix, &indexVec, desc.contribBitWidth, desc.bitWidth);
1962+
Value *MatPtr = nullptr;
1963+
Value *ptrToElem = getAcc32x64ElementPtr(CI, matrix, index, &builder, &MatPtr);
1964+
element = builder.CreateLoad(builder.getFloatTy(), ptrToElem);
19561965
}
19571966

19581967
// unpack element we need from packed value
@@ -1964,6 +1973,12 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
19641973
// being replaced has a half return type and the vectorElementType is i16
19651974
element = builder.CreateBitCast(element, CI->getType());
19661975

1976+
// Add metadata to mark this value as part of joint_matrix_apply loop
1977+
// It will be used in getUnrollingPreferences to make sure this loop is fully unrolled
1978+
Instruction* elementInst = cast<Instruction>(element);
1979+
MDNode* node = MDNode::get(CI->getContext(), ConstantAsMetadata::get(builder.getInt1(true)));
1980+
elementInst->setMetadata("joint_matrix_apply", node);
1981+
19671982
InstsToErase.insert(CI);
19681983
return element;
19691984
}

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ namespace IGC
6262
llvm::Value *ResolveFill(llvm::CallInst *CI);
6363
llvm::Instruction *ResolveFillChecked(llvm::CallInst *CI);
6464
llvm::Value *ResolveWILength(llvm::CallInst *CI);
65+
llvm::Value *getAcc32x64ElementPtr(llvm::CallInst *CI, llvm::Value *matrix, llvm::Value *index, llvm::IRBuilder<> *builder, llvm::Value **MatPtr);
6566
llvm::Value *ResolveSliceInsert(llvm::CallInst *CI);
6667
llvm::Value *ResolveSliceExtract(llvm::CallInst *CI);
6768
llvm::Instruction *ResolveGetCoord(llvm::CallInst *CI);
@@ -113,6 +114,7 @@ namespace IGC
113114

114115
llvm::ValueMap<llvm::Value *, llvm::Instruction *> PlaceholderInstructions;
115116
llvm::ValueMap<llvm::Value *, llvm::Value *> ResolvedValues;
117+
llvm::ValueMap<llvm::Value *, llvm::Value *> MatrixAllocas;
116118
std::unordered_map<llvm::Type *, llvm::Type *> ResolvedTypes;
117119
llvm::SmallPtrSet<llvm::Instruction *, 8> InstsToErase;
118120
// Maps function to it's kernel entry function

IGC/Compiler/tests/JointMatrixFuncsResolutionPass/extract_insert.ll

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
; RUN: igc_opt -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
1010
; ------------------------------------------------
1111
; JointMatrixFuncsResolutionPass
12+
;
13+
; Test verifies resolution of joint matrix extract and insert functions,
14+
; including adding of joint_matrix_apply metadata.
1215
; ------------------------------------------------
1316

1417
%spirv.JointMatrixINTEL._float_16_16_3_3_2 = type opaque
@@ -18,53 +21,47 @@
1821
; CHECK-SAME: float addrspace(1)* [[PTR1:%.*]], i64 [[IND1:%.*]], float addrspace(1)* [[PTR2:%.*]], i64 [[IND2:%.*]]) {
1922
define spir_kernel void @test(float addrspace(1)* %ptr1, i64 %ind1, float addrspace(1)* %ptr2, i64 %ind2) {
2023
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x <64 x float>]
21-
; CHECK-NEXT: [[TMP2:%.*]] = alloca <16 x float>
24+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x <64 x float>]
25+
; CHECK-NEXT: [[TMP3:%.*]] = alloca <16 x float>
2226

23-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <16 x float>* [[TMP2]] to i8*
24-
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_global_v8i8_pi32_i32(i8* [[TMP3]], float addrspace(1)* [[PTR1]], i64 32, i32 0)
25-
; CHECK-NEXT: [[TMP4:%.*]] = load <16 x float>, <16 x float>* [[TMP2]]
27+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x float>* [[TMP3]] to i8*
28+
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_global_v8i8_pi32_i32(i8* [[TMP4]], float addrspace(1)* [[PTR1]], i64 32, i32 0)
29+
; CHECK-NEXT: [[TMP5:%.*]] = load <16 x float>, <16 x float>* [[TMP3]]
2630
%C1 = call spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace(1)* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fliii(float addrspace(1)* %ptr1, i64 32, i32 0, i32 3, i32 0)
2731

28-
; CHECK-NEXT: [[MATRIX_ELEMENT:%.*]] = extractelement <16 x float> [[TMP4]], i64 [[IND1]]
32+
; CHECK-NEXT: [[MATRIX_ELEMENT:%.*]] = extractelement <16 x float> [[TMP5]], i64 [[IND1]]
2933
%1 = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2l(%spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace(1)* %C1, i64 %ind1)
3034

31-
; CHECK-NEXT: [[TMP5:%.*]] = fadd float [[MATRIX_ELEMENT]], 5.000000e+00
35+
; CHECK-NEXT: [[TMP6:%.*]] = fadd float [[MATRIX_ELEMENT]], 5.000000e+00
3236
%2 = fadd float %1, 5.0
3337

34-
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <16 x float> [[TMP4]], float [[TMP5]], i64 [[IND1]]
38+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <16 x float> [[TMP5]], float [[TMP6]], i64 [[IND1]]
3539
%3 = call spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace(1)* @_Z27__spirv_VectorInsertDynamicPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2fl(%spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace(1)* %C1, float %2, i64 %ind1)
3640

37-
; CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to i8*
38-
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_global_v8i8_pi32_i32(i8* [[TMP7]], float addrspace(1)* [[PTR2]], i64 128, i32 0)
39-
; CHECK-NEXT: [[TMP8:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to <64 x float>*
40-
; CHECK-NEXT: [[TMP9:%.*]] = load <64 x float>, <64 x float>* [[TMP8]]
41-
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr <64 x float>, <64 x float>* [[TMP8]], i32 1
42-
; CHECK-NEXT: [[TMP11:%.*]] = load <64 x float>, <64 x float>* [[TMP10]]
43-
; CHECK-NEXT: [[TMP12:%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP9]], 0
44-
; CHECK-NEXT: [[TMP13:%.*]] = insertvalue [2 x <64 x float>] [[TMP12]], <64 x float> [[TMP11]], 1
41+
; CHECK-NEXT: [[TMP8:%.*]] = bitcast [2 x <64 x float>]* [[TMP2]] to i8*
42+
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_global_v8i8_pi32_i32(i8* [[TMP8]], float addrspace(1)* [[PTR2]], i64 128, i32 0)
43+
; CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x <64 x float>]* [[TMP2]] to <64 x float>*
44+
; CHECK-NEXT: [[TMP10:%.*]] = load <64 x float>, <64 x float>* [[TMP9]]
45+
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr <64 x float>, <64 x float>* [[TMP9]], i32 1
46+
; CHECK-NEXT: [[TMP12:%.*]] = load <64 x float>, <64 x float>* [[TMP11]]
47+
; CHECK-NEXT: [[TMP13:%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP10]], 0
48+
; CHECK-NEXT: [[TMP14:%.*]] = insertvalue [2 x <64 x float>] [[TMP13]], <64 x float> [[TMP12]], 1
4549
%C2 = call spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2PU3AS1fliii(float addrspace(1)* %ptr2, i64 128, i32 0, i32 3, i32 0)
4650

47-
; CHECK-NEXT: [[TMP14:%.*]] = icmp ugt i64 [[IND2]], 63
48-
; CHECK-NEXT: [[MATRIX_SLICE_HALF0:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 0
49-
; CHECK-NEXT: [[MATRIX_SLICE_HALF1:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 1
50-
; CHECK-NEXT: [[MATRIX_SLICE_SELECTED_HALF:%.*]] = select i1 [[TMP14]], <64 x float> [[MATRIX_SLICE_HALF1]], <64 x float> [[MATRIX_SLICE_HALF0]]
51-
; CHECK-NEXT: [[TMP15:%.*]] = urem i64 [[IND2]], 64
52-
; CHECK-NEXT: [[MATRIX_ELEMENT5:%.*]] = extractelement <64 x float> [[MATRIX_SLICE_SELECTED_HALF]], i64 [[TMP15]]
51+
; CHECK-NEXT: store [2 x <64 x float>] [[TMP14]], [2 x <64 x float>]* [[TMP1]]
52+
; CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to float*
53+
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr float, float* [[TMP15]], i64 [[IND2]]
54+
; CHECK-NEXT: [[TMP17:%.*]] = load float, float* [[TMP16]],{{.*}} !joint_matrix_apply [[MD:![0-9]+]]
5355
%4 = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2l(%spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* %C2, i64 %ind2)
5456

55-
; CHECK-NEXT: [[TMP16:%.*]] = fadd float [[MATRIX_ELEMENT5]], 5.000000e+00
57+
; CHECK-NEXT: [[TMP18:%.*]] = fadd float [[TMP17]], 5.000000e+00
5658
%5 = fadd float %4, 5.0
5759

58-
; CHECK-NEXT: [[TMP17:%.*]] = icmp ugt i64 [[IND2]], 63
59-
; CHECK-NEXT: [[MATRIX_SLICE_HALF07:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 0
60-
; CHECK-NEXT: [[MATRIX_SLICE_HALF18:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 1
61-
; CHECK-NEXT: [[MATRIX_SLICE_SELECTED_HALF9:%.*]] = select i1 [[TMP17]], <64 x float> [[MATRIX_SLICE_HALF18]], <64 x float> [[MATRIX_SLICE_HALF07]]
62-
; CHECK-NEXT: [[TMP18:%.*]] = urem i64 [[IND2]], 64
63-
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <64 x float> [[MATRIX_SLICE_SELECTED_HALF9]], float [[TMP16]], i64 [[TMP18]]
64-
; CHECK-NEXT: [[TMP20:%.*]] = select i1 [[TMP17]], <64 x float> [[MATRIX_SLICE_HALF07]], <64 x float> [[TMP19]]
65-
; CHECK-NEXT: [[TMP21:%.*]] = select i1 [[TMP17]], <64 x float> [[TMP19]], <64 x float> [[MATRIX_SLICE_HALF18]]
66-
; CHECK-NEXT: [[TMP22:%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP20]], 0
67-
; CHECK-NEXT: [[TMP23:%.*]] = insertvalue [2 x <64 x float>] [[TMP22]], <64 x float> [[TMP21]], 1
60+
; CHECK-NEXT: store [2 x <64 x float>] [[TMP14]], [2 x <64 x float>]* [[TMP1]]
61+
; CHECK-NEXT: [[TMP19:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to float*
62+
; CHECK-NEXT: [[TMP20:%.*]] = getelementptr float, float* [[TMP19]], i64 [[IND2]]
63+
; CHECK-NEXT: store float [[TMP18]], float* [[TMP20]]
64+
; CHECK-NEXT: [[TMP21:%.*]] = load [2 x <64 x float>], [2 x <64 x float>]* [[TMP1]]
6865
%6 = call spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* @_Z27__spirv_VectorInsertDynamicPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2fl(%spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* %C2, float %5, i64 %ind2)
6966

7067
; CHECK-NEXT: ret void
@@ -79,6 +76,7 @@ declare spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* @_Z27
7976
declare spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2PU3AS1fliii(float addrspace(1)*, i64, i32, i32, i32)
8077
declare spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace(1)* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fliii(float addrspace(1)*, i64, i32, i32, i32)
8178

79+
; CHECK: [[MD]] = !{i1 true}
8280
!igc.functions = !{!0}
8381
!0 = !{void (float addrspace(1)*, i64, float addrspace(1)*, i64)* @test, !1}
8482
!1 = !{!2, !3}

0 commit comments

Comments
 (0)