Skip to content

Commit 804ed3c

Browse files
jgu222igcbot
authored andcommitted
symexpr handles or and shl
Improve symexpr to handle some special instructions: 1. or with a constant : change to add if it can be, and 2. shl by a constant : change to mul
1 parent c82d949 commit 804ed3c

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

IGC/Compiler/CISACodeGen/SLMConstProp.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,36 @@ void SymbolicEvaluation::getSymExprOrConstant(const Value* V, SymExpr*& S, int64
372372

373373
if (const ConstantInt * CI = dyn_cast<const ConstantInt>(V))
374374
{
375+
// symExpr handles symbols with the same bit size, thus sext/zext/trunc
376+
// are not handled. With this, a result from either signed or unsigned
377+
// integer operations will end up with the same bit pattern. Here, we
378+
// choose to use sext on constants.
375379
C = CI->getSExtValue();
376380
return;
377381
}
378382

383+
// Used for nomalizing shift amount.
384+
// For example, i64, its type mask is 03F(63 = 64 - 1).
385+
auto getTypeMask = [](const Type* Ty) -> uint32_t {
386+
// For simplicity, only handle type whose size is power of 2.
387+
uint32_t nbits = Ty->getScalarSizeInBits();
388+
if (nbits > 0 && isPowerOf2_32(nbits))
389+
return (nbits - 1);
390+
return 0;
391+
};
392+
393+
// Return value:
394+
// Shift amount: if it is valid and greater than 0
395+
// 0 : invalid
396+
auto getShlAmt = [&getTypeMask](const Instruction* ShlInst) -> uint32_t {
397+
IGC_ASSERT(ShlInst->getOpcode() == Instruction::Shl);
398+
uint32_t shtAmtMask = getTypeMask(ShlInst->getType());
399+
ConstantInt* cI = cast<ConstantInt>(ShlInst->getOperand(1));
400+
if (cI && shtAmtMask > 0)
401+
return (uint32_t)(cI->getZExtValue() & shtAmtMask);
402+
return 0;
403+
};
404+
379405
// Instructions/Operators handled for now:
380406
// GEP
381407
// bitcast (inttoptr, ptrtoint, etc)
@@ -515,6 +541,35 @@ void SymbolicEvaluation::getSymExprOrConstant(const Value* V, SymExpr*& S, int64
515541
}
516542
break;
517543
}
544+
case Instruction::Or:
545+
{
546+
// Check if it is actually an add.
547+
//
548+
// %mul = shl nuw nsw i64 %v, 1
549+
// %add = or i64 %mul, 1
550+
// --> %add = add %mul, 1
551+
const Value* V0 = Op->getOperand(0);
552+
const Value* V1 = Op->getOperand(1);
553+
getSymExprOrConstant(V0, S0, C0);
554+
getSymExprOrConstant(V1, S1, C1);
555+
if (!S0 && !S1) {
556+
C = C0 | C1;
557+
return;
558+
}
559+
560+
// Case: 'or V0 Const' or 'or const V1'
561+
if ((S0 && !S1) || (!S0 && S1)) {
562+
const Value* tV = (S0 ? V0 : V1);
563+
const uint64_t tC = (uint64_t)(S0 ? C1 : C0);
564+
const Instruction* tI = dyn_cast<Instruction>(tV);
565+
if (tI && tI->getOpcode() == Instruction::Shl) {
566+
uint32_t shtAmt = getShlAmt(tI);
567+
if (shtAmt > 0 && (1ull << shtAmt) > tC)
568+
S = add(S0 ? S0 : S1, tC);
569+
}
570+
}
571+
break;
572+
}
518573
case Instruction::Mul:
519574
{
520575
const Value* V0 = Op->getOperand(0);
@@ -538,7 +593,34 @@ void SymbolicEvaluation::getSymExprOrConstant(const Value* V, SymExpr*& S, int64
538593

539594
break;
540595
}
596+
case Instruction::Shl:
597+
{
598+
// shl is a mul
599+
// shl a, b, 2
600+
// -> mul a, b, (1 << 2)
601+
const Value* V0 = Op->getOperand(0);
602+
const Value* V1 = Op->getOperand(1);
603+
getSymExprOrConstant(V0, S0, C0);
604+
getSymExprOrConstant(V1, S1, C1);
605+
606+
uint32_t shtAmtMask = getTypeMask(V->getType());
607+
if (shtAmtMask == 0) // sanity
608+
break;
541609

610+
if (!S1) {
611+
C1 = (C1 & shtAmtMask);
612+
}
613+
614+
if (!S0 && !S1) {
615+
C = (C0 << C1);
616+
return;
617+
}
618+
if (!S1) {
619+
uint64_t tC = (1ull << C1);
620+
S = mul(S0, tC);
621+
}
622+
break;
623+
}
542624
case Instruction::BitCast:
543625
case Instruction::IntToPtr:
544626
case Instruction::PtrToInt:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2017-2023 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
10+
11+
12+
; Given a0 = shl a, 1
13+
; a1 = getelementptr inbounds <2 x i32>, <2 x i32> addrspace(1)* %d, i64 %a0
14+
; store <2 x i32> %v0, <2 x i32> addrspace(1)* %a1, align 8
15+
; a2 = or i64 %a0, 1
16+
; a3 = getelementptr inbounds <2 x i32>, <2 x i32> addrspace(1)* %d, i64 %a2
17+
; store <2 x i32> %v1, <2 x i32> addrspace(1)* %a3, align 8
18+
; combined into
19+
; store <4xi32>
20+
;
21+
; This is to test that symbolic expression can handle 'or' and 'shl' instructions
22+
;
23+
; CHECK-LABEL: target datalayout
24+
; CHECK: %__StructSOALayout_ = type <{ <2 x i32>, <2 x i32> }>
25+
; CHECK-LABEL: define spir_kernel void @test_st
26+
; CHECK: load <4 x i32>,
27+
; CHECK: [[TMP1:%.*]] = insertvalue %__StructSOALayout_ undef, <2 x i32> %{{.*}}, 0
28+
; CHECK: [[TMP2:%.*]] = insertvalue %__StructSOALayout_ [[TMP1]], <2 x i32> %{{.*}}, 1
29+
; CHECK: [[TMP3:%.*]] = call <4 x i32> @llvm.genx.GenISA.bitcastfromstruct.v4i32.__StructSOALayout_(%__StructSOALayout_ [[TMP2]])
30+
; CHECK: store <4 x i32> [[TMP3]]
31+
; CHECK: ret void
32+
33+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32"
34+
target triple = "spir64-unknown-unknown"
35+
36+
; Function Attrs: convergent nounwind
37+
define spir_kernel void @test_st(<2 x i32> addrspace(1)* %d, <4 x i32> addrspace(1)* %s, <8 x i32> %r0, <8 x i32> %payloadHeader, <3 x i32> %enqueuedLocalSize, i16 %localIdX, i16 %localIdY, i16 %localIdZ) #0 {
38+
entry:
39+
%payloadHeader.scalar = extractelement <8 x i32> %payloadHeader, i32 0
40+
%enqueuedLocalSize.scalar = extractelement <3 x i32> %enqueuedLocalSize, i32 0
41+
%r0.scalar17 = extractelement <8 x i32> %r0, i32 1
42+
%mul.i.i.i = mul i32 %enqueuedLocalSize.scalar, %r0.scalar17
43+
%localIdX2 = zext i16 %localIdX to i32
44+
%add.i.i.i = add i32 %mul.i.i.i, %localIdX2
45+
%add4.i.i.i = add i32 %add.i.i.i, %payloadHeader.scalar
46+
%conv.i.i.i = zext i32 %add4.i.i.i to i64
47+
%mul = shl nuw nsw i64 %conv.i.i.i, 1
48+
%arrayidx = getelementptr inbounds <4 x i32>, <4 x i32> addrspace(1)* %s, i64 %mul
49+
%0 = load <4 x i32>, <4 x i32> addrspace(1)* %arrayidx, align 16
50+
%vecinit1.assembled.vect36 = shufflevector <4 x i32> %0, <4 x i32> undef, <2 x i32> <i32 0, i32 1>
51+
%vecinit4.assembled.vect37 = shufflevector <4 x i32> %0, <4 x i32> undef, <2 x i32> <i32 2, i32 3>
52+
%arrayidx5 = getelementptr inbounds <2 x i32>, <2 x i32> addrspace(1)* %d, i64 %mul
53+
store <2 x i32> %vecinit1.assembled.vect36, <2 x i32> addrspace(1)* %arrayidx5, align 8
54+
%add = or i64 %mul, 1
55+
%arrayidx6 = getelementptr inbounds <2 x i32>, <2 x i32> addrspace(1)* %d, i64 %add
56+
store <2 x i32> %vecinit4.assembled.vect37, <2 x i32> addrspace(1)* %arrayidx6, align 8
57+
ret void
58+
}

0 commit comments

Comments
 (0)