Skip to content

Commit 79fd221

Browse files
committed
[X86] Use GFNI for vXi8 per-element shifts
As detailed here: https://github.com/InstLatx64/InstLatX64_Demo/blob/master/GFNI_Demo.h - these are a bit more complicated than the gf2p8affine lookups, as they require us to convert the SHL shift value and shift amount into GF(2^8) operands so that the shift can be performed as a Galois field multiplication (GF2P8MULB). SRL/SRA need to be converted to SHL via bitreverse / variable sign-extension. Follow-up to #89115
1 parent 2265df9 commit 79fd221

File tree

3 files changed

+1435
-2411
lines changed

3 files changed

+1435
-2411
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29586,6 +29586,62 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
2958629586
DAG.getNode(Opc, dl, ExtVT, R, Amt));
2958729587
}
2958829588

29589+
// GFNI - we can perform SHL with a GF multiplication, and can convert
29590+
// SRL/SRA to a SHL.
29591+
if (VT == MVT::v16i8 ||
29592+
(VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
29593+
(VT == MVT::v64i8 && Subtarget.hasBWI())) {
29594+
if (Subtarget.hasGFNI() && Subtarget.hasSSSE3()) {
29595+
auto GFShiftLeft = [&](SDValue Val) {
29596+
// Use PSHUFB as a LUT from the shift amount to create a per-element
29597+
// byte mask for the shift value and an index. For shift amounts greater
29598+
// than 7, the result will be zero.
29599+
SmallVector<APInt, 8> MaskBits, IdxBits;
29600+
for (unsigned I = 0, E = VT.getSizeInBits() / 128; I != E; ++I) {
29601+
MaskBits.push_back(APInt(64, 0x0103070F1F3F7FFFULL));
29602+
IdxBits.push_back(APInt(64, 0x8040201008040201ULL));
29603+
MaskBits.push_back(APInt::getZero(64));
29604+
IdxBits.push_back(APInt::getZero(64));
29605+
}
29606+
29607+
MVT CVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64);
29608+
SDValue Mask =
29609+
DAG.getBitcast(VT, getConstVector(MaskBits, CVT, DAG, dl));
29610+
SDValue Idx = DAG.getBitcast(VT, getConstVector(IdxBits, CVT, DAG, dl));
29611+
Mask = DAG.getNode(X86ISD::PSHUFB, dl, VT, Mask, Amt);
29612+
Idx = DAG.getNode(X86ISD::PSHUFB, dl, VT, Idx, Amt);
29613+
Mask = DAG.getNode(ISD::AND, dl, VT, Val, Mask);
29614+
return DAG.getNode(X86ISD::GF2P8MULB, dl, VT, Mask, Idx);
29615+
};
29616+
29617+
if (Opc == ISD::SHL)
29618+
return GFShiftLeft(R);
29619+
29620+
// srl(x,y)
29621+
// --> bitreverse(shl(bitreverse(x),y))
29622+
if (Opc == ISD::SRL) {
29623+
R = DAG.getNode(ISD::BITREVERSE, dl, VT, R);
29624+
R = GFShiftLeft(R);
29625+
return DAG.getNode(ISD::BITREVERSE, dl, VT, R);
29626+
}
29627+
29628+
// sra(x,y)
29629+
// --> sub(xor(srl(x,y), m),m)
29630+
// --> sub(xor(bitreverse(shl(bitreverse(x),y)), m),m)
29631+
// where m = srl(signbit, amt) --> bitreverse(shl(lsb, amt))
29632+
if (Opc == ISD::SRA) {
29633+
SDValue LSB = DAG.getConstant(APInt::getOneBitSet(8, 0), dl, VT);
29634+
SDValue M = DAG.getNode(ISD::BITREVERSE, dl, VT, GFShiftLeft(LSB));
29635+
R = DAG.getNode(ISD::BITREVERSE, dl, VT, R);
29636+
R = GFShiftLeft(R);
29637+
R = DAG.getNode(ISD::BITREVERSE, dl, VT, R);
29638+
R = DAG.getNode(ISD::XOR, dl, VT, R, M);
29639+
R = DAG.getNode(ISD::SUB, dl, VT, R, M);
29640+
return R;
29641+
}
29642+
}
29643+
}
29644+
2958929645
// Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
2959029646
// extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
2959129647
if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
@@ -55645,6 +55701,15 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5564555701
ConcatSubOperand(VT, Ops, 0));
5564655702
}
5564755703
break;
55704+
case X86ISD::GF2P8MULB:
55705+
if (!IsSplat &&
55706+
(VT.is256BitVector() ||
55707+
(VT.is512BitVector() && Subtarget.useAVX512Regs()))) {
55708+
return DAG.getNode(Op0.getOpcode(), DL, VT,
55709+
ConcatSubOperand(VT, Ops, 0),
55710+
ConcatSubOperand(VT, Ops, 1));
55711+
}
55712+
break;
5564855713
case X86ISD::GF2P8AFFINEQB:
5564955714
if (!IsSplat &&
5565055715
(VT.is256BitVector() ||

0 commit comments

Comments
 (0)