Skip to content

Commit 137a062

Browse files
authored
Revert "[RISCV] Initial codegen support for zvqdotq extension (#137039)"
This reverts commit 1ac489c.
1 parent 1ac489c commit 137a062

File tree

4 files changed

+68
-310
lines changed

4 files changed

+68
-310
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 3 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
69716971
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69726972
"not a RISC-V target specific op");
69736973
static_assert(
6974-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
6974+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
69756975
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
69766976
"adding target specific op should update this function");
69776977
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
69956995
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69966996
"not a RISC-V target specific op");
69976997
static_assert(
6998-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
6998+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
69996999
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
70007000
"adding target specific op should update this function");
70017001
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,118 +18101,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
1810118101
DAG.getBuildVector(VT, DL, RHSOps));
1810218102
}
1810318103

18104-
static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18105-
const SDLoc &DL, SelectionDAG &DAG,
18106-
const RISCVSubtarget &Subtarget) {
18107-
assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18108-
RISCVISD::VQDOTSU_VL == Opc);
18109-
MVT VT = Op0.getSimpleValueType();
18110-
assert(VT == Op1.getSimpleValueType() &&
18111-
VT.getVectorElementType() == MVT::i32);
18112-
18113-
assert(VT.isFixedLengthVector());
18114-
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18115-
SDValue Passthru = convertToScalableVector(
18116-
ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18117-
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18118-
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18119-
18120-
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18121-
const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
18122-
SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
18123-
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18124-
{Op0, Op1, Passthru, Mask, VL, PolicyOp});
18125-
return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18126-
}
18127-
18128-
static MVT getQDOTXResultType(MVT OpVT) {
18129-
ElementCount OpEC = OpVT.getVectorElementCount();
18130-
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
18131-
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
18132-
}
18133-
18134-
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
18135-
SelectionDAG &DAG,
18136-
const RISCVSubtarget &Subtarget,
18137-
const RISCVTargetLowering &TLI) {
18138-
// Note: We intentionally do not check the legality of the reduction type.
18139-
// We want to handle the m4/m8 *src* types, and thus need to let illegal
18140-
// intermediate types flow through here.
18141-
if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
18142-
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
18143-
return SDValue();
18144-
18145-
// reduce (zext a) <--> reduce (mul zext a. zext 1)
18146-
// reduce (sext a) <--> reduce (mul sext a. sext 1)
18147-
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
18148-
InVec.getOpcode() == ISD::SIGN_EXTEND) {
18149-
SDValue A = InVec.getOperand(0);
18150-
if (A.getValueType().getVectorElementType() != MVT::i8 ||
18151-
!TLI.isTypeLegal(A.getValueType()))
18152-
return SDValue();
18153-
18154-
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18155-
A = DAG.getBitcast(ResVT, A);
18156-
SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18157-
18158-
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18159-
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18160-
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18161-
}
18162-
18163-
// mul (sext, sext) -> vqdot
18164-
// mul (zext, zext) -> vqdotu
18165-
// mul (sext, zext) -> vqdotsu
18166-
// mul (zext, sext) -> vqdotsu (swapped)
18167-
// TODO: Improve .vx handling - we end up with a sub-vector insert
18168-
// which confuses the splat pattern matching. Also, match vqdotus.vx
18169-
if (InVec.getOpcode() != ISD::MUL)
18170-
return SDValue();
18171-
18172-
SDValue A = InVec.getOperand(0);
18173-
SDValue B = InVec.getOperand(1);
18174-
unsigned Opc = 0;
18175-
if (A.getOpcode() == B.getOpcode()) {
18176-
if (A.getOpcode() == ISD::SIGN_EXTEND)
18177-
Opc = RISCVISD::VQDOT_VL;
18178-
else if (A.getOpcode() == ISD::ZERO_EXTEND)
18179-
Opc = RISCVISD::VQDOTU_VL;
18180-
else
18181-
return SDValue();
18182-
} else {
18183-
if (B.getOpcode() != ISD::ZERO_EXTEND)
18184-
std::swap(A, B);
18185-
if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18186-
return SDValue();
18187-
Opc = RISCVISD::VQDOTSU_VL;
18188-
}
18189-
assert(Opc);
18190-
18191-
if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18192-
A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18193-
!TLI.isTypeLegal(A.getValueType()))
18194-
return SDValue();
18195-
18196-
MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18197-
A = DAG.getBitcast(ResVT, A.getOperand(0));
18198-
B = DAG.getBitcast(ResVT, B.getOperand(0));
18199-
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18200-
}
18201-
18202-
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
18203-
const RISCVSubtarget &Subtarget,
18204-
const RISCVTargetLowering &TLI) {
18205-
if (!Subtarget.hasStdExtZvqdotq())
18206-
return SDValue();
18207-
18208-
SDLoc DL(N);
18209-
EVT VT = N->getValueType(0);
18210-
SDValue InVec = N->getOperand(0);
18211-
if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
18212-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
18213-
return SDValue();
18214-
}
18215-
1821618104
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
1821718105
const RISCVSubtarget &Subtarget,
1821818106
const RISCVTargetLowering &TLI) {
@@ -19990,11 +19878,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1999019878

1999119879
return SDValue();
1999219880
}
19993-
case ISD::VECREDUCE_ADD:
19994-
if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
19995-
return V;
19996-
[[fallthrough]];
1999719881
case ISD::CTPOP:
19882+
case ISD::VECREDUCE_ADD:
1999819883
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
1999919884
return V;
2000019885
break;
@@ -22516,9 +22401,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2251622401
NODE_NAME_CASE(RI_VUNZIP2A_VL)
2251722402
NODE_NAME_CASE(RI_VUNZIP2B_VL)
2251822403
NODE_NAME_CASE(RI_VEXTRACT)
22519-
NODE_NAME_CASE(VQDOT_VL)
22520-
NODE_NAME_CASE(VQDOTU_VL)
22521-
NODE_NAME_CASE(VQDOTSU_VL)
2252222404
NODE_NAME_CASE(READ_CSR)
2252322405
NODE_NAME_CASE(WRITE_CSR)
2252422406
NODE_NAME_CASE(SWAP_CSR)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,12 +416,7 @@ enum NodeType : unsigned {
416416
RI_VUNZIP2A_VL,
417417
RI_VUNZIP2B_VL,
418418

419-
// zvqdot instructions with additional passthru, mask and VL operands
420-
VQDOT_VL,
421-
VQDOTU_VL,
422-
VQDOTSU_VL,
423-
424-
LAST_VL_VECTOR_OP = VQDOTSU_VL,
419+
LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
425420

426421
// XRivosVisni
427422
// VEXTRACT matches the semantics of ri.vextract.x.v. The result is always

llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,34 +26,3 @@ let Predicates = [HasStdExtZvqdotq] in {
2626
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
2727
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
2828
} // Predicates = [HasStdExtZvqdotq]
29-
30-
31-
def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
32-
def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
33-
def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
34-
35-
multiclass VPseudoVQDOT_VV_VX {
36-
foreach m = MxSet<32>.m in {
37-
defm "" : VPseudoBinaryV_VV<m>,
38-
SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
39-
forcePassthruRead=true>;
40-
defm "" : VPseudoBinaryV_VX<m>,
41-
SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
42-
forcePassthruRead=true>;
43-
}
44-
}
45-
46-
// TODO: Add pseudo and patterns for vqdotus.vx
47-
// TODO: Add isCommutable for VQDOT and VQDOTU
48-
let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
49-
hasSideEffects = 0 in {
50-
defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
51-
defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
52-
defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
53-
}
54-
55-
defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
56-
defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
57-
defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
58-
defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
59-

0 commit comments

Comments
 (0)