Skip to content

Commit 47e82b9

Browse files
zuban32sys_zuul
authored andcommitted
Consider SIMD mask while constfolding
Change-Id: I9d1fb5b2466c61d9a593f65e40c76c245ec74448
1 parent 3f80901 commit 47e82b9

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

IGC/VectorCompiler/lib/GenXOpts/CMAnalysis/ConstantFoldingGenX.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ static Constant *constantFoldWrRegion(Type *RetTy,
135135
const CMRegion &R, const DataLayout &DL) {
136136
Constant *OldValue = Operands[GenXIntrinsic::GenXRegion::OldValueOperandNum];
137137
Constant *NewValue = Operands[GenXIntrinsic::GenXRegion::NewValueOperandNum];
138+
Constant *Mask = Operands[GenXIntrinsic::GenXRegion::PredicateOperandNum];
138139
// The inputs can be ConstantExpr if we are being called from
139140
// CallAnalyzer.
140141
if (isa<ConstantExpr>(OldValue) || isa<ConstantExpr>(NewValue))
@@ -147,7 +148,8 @@ static Constant *constantFoldWrRegion(Type *RetTy,
147148

148149
const int RetElemSize = DL.getTypeSizeInBits(RetTy->getScalarType()) / 8;
149150
unsigned Offset = OffsetC->getSExtValue() / RetElemSize;
150-
if (isa<UndefValue>(OldValue) && R.isContiguous() && (Offset == 0)) {
151+
if (isa<UndefValue>(OldValue) && R.isContiguous() && Offset == 0 &&
152+
Mask->isAllOnesValue()) {
151153
// If old value is undef and new value is splat, and the result vector
152154
// is no bigger than 2 GRFs, then just return a splat of the right type.
153155
Constant *Splat = NewValue;
@@ -172,7 +174,7 @@ static Constant *constantFoldWrRegion(Type *RetTy,
172174
return UndefValue::get(RetTy); // out of range index
173175
if (!isa<VectorType>(NewValue->getType()))
174176
Values[Offset] = NewValue;
175-
else {
177+
else if (!Mask->isZeroValue()) {
176178
unsigned RowIdx = Offset;
177179
unsigned Idx = RowIdx;
178180
unsigned NextRow = R.Width;
@@ -185,7 +187,10 @@ static Constant *constantFoldWrRegion(Type *RetTy,
185187
if (Idx >= WholeNumElements)
186188
// return collected values even if idx is out of bounds
187189
return ConstantVector::get(Values);
188-
Values[Idx] = NewValue->getAggregateElement(i);
190+
if (Mask->isAllOnesValue() ||
191+
(Mask->getType()->isVectorTy() &&
192+
!cast<ConstantVector>(Mask)->getAggregateElement(i)->isZeroValue()))
193+
Values[Idx] = NewValue->getAggregateElement(i);
189194
Idx += R.Stride;
190195
}
191196
}

0 commit comments

Comments
 (0)