Skip to content

Commit a91728b

Browse files
admitricigcbot
authored and committed
Transform shl+ashr to extractelement for i8
Extend the existing shl+ashr to extractelement transformation so that it also handles the i8 element type. This saves latency on arithmetic instructions by using register regioning instead of the two shifts.
1 parent 18f84d6 commit a91728b

File tree

2 files changed

+151
-19
lines changed

2 files changed

+151
-19
lines changed

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,7 +2834,7 @@ void GenSpecificPattern::visitAnd(BinaryOperator& I)
28342834
}
28352835
}
28362836

2837-
void GenSpecificPattern::visitAShr(BinaryOperator& I)
2837+
void GenSpecificPattern::visitAShr(BinaryOperator &I)
28382838
{
28392839
/*
28402840
From:
@@ -2850,36 +2850,70 @@ void GenSpecificPattern::visitAShr(BinaryOperator& I)
28502850
%132 = sext i8 %ee1 to i32
28512851
%133 = sext i8 %ee2 to i32
28522852
Which will end up as regioning instead of 2 instructions.
2853+
2854+
Also change shl 24 + ashr 24 -> extractelement <4 x i8> %temp, i32 0 (followed by sext i8 to i32)
28532855
*/
28542856

28552857
llvm::IRBuilder<> builder(&I);
28562858
using namespace llvm::PatternMatch;
28572859

2858-
Instruction* AShrSrc = nullptr;
2859-
auto pattern_1 = m_AShr(m_Instruction(AShrSrc), m_SpecificInt(16));
2860-
2861-
if (match(&I, pattern_1) && I.getType()->isIntegerTy(32) && AShrSrc && AShrSrc->getType()->isIntegerTy(32))
2860+
auto tryTransformAsrToEE = [&](Instruction &I, uint32_t BaseTypeSize, uint32_t ElemSize)
28622861
{
2863-
Instruction* ShlSrc = nullptr;
2862+
IGC_ASSERT(BaseTypeSize % ElemSize == 0);
2863+
2864+
auto *BaseType = builder.getIntNTy(BaseTypeSize);
2865+
if (I.getType() != BaseType)
2866+
return false;
28642867

2865-
auto Shl_Pattern = m_Shl(m_Instruction(ShlSrc), m_SpecificInt(16));
2866-
bool submatch = match(AShrSrc, Shl_Pattern) && ShlSrc && ShlSrc->getType()->isIntegerTy(32);
2868+
Value *AShrSrc = nullptr;
2869+
uint32_t ShiftBits = BaseTypeSize - ElemSize;
2870+
auto AShrPattern = m_AShr(m_Value(AShrSrc), m_SpecificInt(ShiftBits));
28672871

2868-
// in case there's no shr, we take upper half
2869-
uint32_t newIndex = 1;
2872+
if (!match(&I, AShrPattern))
2873+
return false;
2874+
if (!AShrSrc || AShrSrc->getType() != BaseType)
2875+
return false;
28702876

2871-
// if there was Shl, we take lower half
2872-
if (submatch)
2877+
Value *ShlSrc = nullptr;
2878+
2879+
auto ShlPattern = m_Shl(m_Value(ShlSrc), m_SpecificInt(ShiftBits));
2880+
bool ShlMatch = match(AShrSrc, ShlPattern) && ShlSrc && ShlSrc->getType() == BaseType;
2881+
2882+
uint32_t Index = 0;
2883+
Value *BaseValue = nullptr;
2884+
if (ShlMatch)
28732885
{
2874-
AShrSrc = ShlSrc;
2875-
newIndex = 0;
2886+
BaseValue = ShlSrc;
2887+
Index = 0;
28762888
}
2877-
VectorType* vec2 = VectorType::get(builder.getInt16Ty(), 2, false);
2878-
Value* BC = builder.CreateBitCast(AShrSrc, vec2);
2879-
Value* EE = builder.CreateExtractElement(BC, builder.getInt32(newIndex));
2880-
Value* Sext = builder.CreateSExt(EE, builder.getInt32Ty());
2881-
I.replaceAllUsesWith(Sext);
2889+
else if (ShiftBits * 2 == BaseTypeSize)
2890+
{
2891+
// if Shl is not matched we can still make an EE on the AShr source
2892+
// but extract the upper half. Check that we shift by exactly half of the bits.
2893+
BaseValue = AShrSrc;
2894+
Index = 1;
2895+
}
2896+
else
2897+
{
2898+
return false;
2899+
}
2900+
2901+
VectorType *Vec = VectorType::get(builder.getIntNTy(ElemSize), BaseTypeSize / ElemSize, false);
2902+
Value* BC = builder.CreateBitCast(BaseValue, Vec);
2903+
Value* EE = builder.CreateExtractElement(BC, builder.getIntN(BaseTypeSize, Index));
2904+
Value* SExt = builder.CreateSExt(EE, BaseType);
2905+
I.replaceAllUsesWith(SExt);
28822906
I.eraseFromParent();
2907+
2908+
return true;
2909+
};
2910+
2911+
tryTransformAsrToEE(I, 32, 16);
2912+
2913+
CodeGenContext *CTX = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
2914+
if (CTX->platform.supportByteALUOperation())
2915+
{
2916+
tryTransformAsrToEE(I, 32, 8);
28832917
}
28842918
}
28852919

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2024 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
; RUN: igc_opt -igc-gen-specific-pattern --verify --platformdg2 -S %s | FileCheck %s
10+
11+
define spir_kernel void @testkernel_16(i32 %x) {
12+
; CHECK-LABEL: @testkernel_16(
13+
; CHECK: entry:
14+
; CHECK: [[SHL:%.*]] = shl i32 [[X:%.*]], 16
15+
; CHECK: [[TMP0:%.*]] = bitcast i32 [[X]] to <2 x i16>
16+
; CHECK: [[TMP1:%.*]] = extractelement <2 x i16> [[TMP0]], i32 0
17+
; CHECK: [[TMP2:%.*]] = sext i16 [[TMP1]] to i32
18+
; CHECK: [[TMP3:%.*]] = bitcast i32 [[X]] to <2 x i16>
19+
; CHECK: [[TMP4:%.*]] = extractelement <2 x i16> [[TMP3]], i32 1
20+
; CHECK: [[TMP5:%.*]] = sext i16 [[TMP4]] to i32
21+
; CHECK: [[RES:%.*]] = add i32 [[TMP2]], [[TMP5]]
22+
; CHECK: ret void
23+
;
24+
25+
entry:
26+
%Shl = shl i32 %x, 16
27+
%Lo = ashr exact i32 %Shl, 16
28+
%Hi = ashr i32 %x, 16
29+
%Res = add i32 %Lo, %Hi
30+
ret void
31+
}
32+
33+
define spir_kernel void @testkernel_8(i32 %x, <4 x i32> %y) {
34+
; CHECK-LABEL: @testkernel_8(
35+
; CHECK: entry:
36+
37+
; Unchanged part of the code - the common pattern is listed just to show the bigger picture
38+
39+
; CHECK: [[ASTYPE:%.*]] = bitcast i32 [[X:%.*]] to <4 x i8>
40+
; CHECK: [[ASTYPE_SCALAR111:%.*]] = extractelement <4 x i8> [[ASTYPE]], i64 1
41+
; CHECK: [[ASTYPE_SCALAR112:%.*]] = extractelement <4 x i8> [[ASTYPE]], i64 2
42+
; CHECK: [[ASTYPE_SCALAR113:%.*]] = extractelement <4 x i8> [[ASTYPE]], i64 3
43+
44+
; The shl is left in place, but its result is not used anymore
45+
; CHECK: [[SEXT:%.*]] = shl i32 [[X]], 24
46+
47+
; i8 extract element is added
48+
; CHECK: [[TMP0:%.*]] = bitcast i32 [[X]] to <4 x i8>
49+
; CHECK: [[TMP1:%.*]] = extractelement <4 x i8> [[TMP0]], i32 0
50+
; CHECK: [[CONV1:%.*]] = sext i8 [[TMP1]] to i32
51+
52+
; not changed
53+
; CHECK: [[SCALAR1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i64 0
54+
55+
; Ensure EE is used (Conv1)
56+
; CHECK: [[ADD_I1:%.*]] = add nsw i32 [[SCALAR1]], [[CONV1]]
57+
; CHECK: [[SCALAR2:%.*]] = extractelement <4 x i32> [[Y]], i64 1
58+
59+
; The rest of the code - the common pattern is listed to show the bigger picture
60+
; CHECK: [[CONV2:%.*]] = sext i8 [[ASTYPE_SCALAR111]] to i32
61+
; CHECK: [[ADD_I2:%.*]] = add nsw i32 [[SCALAR2]], [[CONV2]]
62+
; CHECK: [[SCALAR3:%.*]] = extractelement <4 x i32> [[Y]], i64 2
63+
; CHECK: [[CONV3:%.*]] = sext i8 [[ASTYPE_SCALAR112]] to i32
64+
; CHECK: [[ADD_I3:%.*]] = add nsw i32 [[SCALAR3]], [[CONV3]]
65+
; CHECK: [[SCALAR4:%.*]] = extractelement <4 x i32> [[Y]], i64 3
66+
; CHECK: [[CONV4:%.*]] = sext i8 [[ASTYPE_SCALAR113]] to i32
67+
; CHECK: [[ADD_I4:%.*]] = add nsw i32 [[SCALAR4]], [[CONV4]]
68+
; CHECK: ret void
69+
;
70+
entry:
71+
%astype = bitcast i32 %x to <4 x i8>
72+
%astype.scalar111 = extractelement <4 x i8> %astype, i64 1
73+
%astype.scalar112 = extractelement <4 x i8> %astype, i64 2
74+
%astype.scalar113 = extractelement <4 x i8> %astype, i64 3
75+
76+
; The transformation applies to this part
77+
; ASHR will be changed to EE
78+
%sext = shl i32 %x, 24
79+
%conv1 = ashr exact i32 %sext, 24
80+
81+
%scalar1 = extractelement <4 x i32> %y, i64 0
82+
%add.i1 = add nsw i32 %scalar1, %conv1
83+
84+
%scalar2 = extractelement <4 x i32> %y, i64 1
85+
%conv2 = sext i8 %astype.scalar111 to i32
86+
%add.i2 = add nsw i32 %scalar2, %conv2
87+
88+
%scalar3 = extractelement <4 x i32> %y, i64 2
89+
%conv3 = sext i8 %astype.scalar112 to i32
90+
%add.i3 = add nsw i32 %scalar3, %conv3
91+
92+
%scalar4 = extractelement <4 x i32> %y, i64 3
93+
%conv4 = sext i8 %astype.scalar113 to i32
94+
%add.i4 = add nsw i32 %scalar4, %conv4
95+
96+
ret void
97+
}
98+

0 commit comments

Comments
 (0)