Skip to content

Commit a1f369e

Browse files
[AArch64][SVE] Add dot product lowering for PARTIAL_REDUCE_MLA node (#130933)
Add TableGen lowering for the PARTIAL_REDUCE_U/SMLA ISD nodes. The lowering only applies once the DAG combine has been performed on the ISD node. Also add a check so that the DAG combine is only performed when the resulting node can eventually be lowered; this changes the NEON tests too. --------- Co-authored-by: James Chesterman <[email protected]>
1 parent 1ce709c commit a1f369e

File tree

9 files changed

+274
-323
lines changed

9 files changed

+274
-323
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,24 @@ class TargetLoweringBase {
16501650
getCondCodeAction(CC, VT) == Custom;
16511651
}
16521652

1653+
/// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
1654+
/// InputVT should be treated. Either it's legal, needs to be promoted to a
1655+
/// larger size, needs to be expanded to some other code sequence, or the
1656+
/// target has a custom expander for it.
1657+
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
1658+
PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
1659+
InputVT.getSimpleVT().SimpleTy};
1660+
auto It = PartialReduceMLAActions.find(TypePair);
1661+
return It != PartialReduceMLAActions.end() ? It->second : Expand;
1662+
}
1663+
1664+
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
1665+
/// legal or custom for this target.
1666+
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
1667+
LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
1668+
return Action == Legal || Action == Custom;
1669+
}
1670+
16531671
/// If the action for this operation is to promote, this method returns the
16541672
/// ValueType to promote to.
16551673
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2727,6 +2745,18 @@ class TargetLoweringBase {
27272745
setCondCodeAction(CCs, VT, Action);
27282746
}
27292747

2748+
/// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
2749+
/// type InputVT should be treated by the target. Either it's legal, needs to
2750+
/// be promoted to a larger size, needs to be expanded to some other code
2751+
/// sequence, or the target has a custom expander for it.
2752+
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
2753+
LegalizeAction Action) {
2754+
assert(AccVT.isValid() && InputVT.isValid() &&
2755+
"setPartialReduceMLAAction types aren't valid");
2756+
PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
2757+
PartialReduceMLAActions[TypePair] = Action;
2758+
}
2759+
27302760
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
27312761
/// to trying a larger integer/fp until it can find one that works. If that
27322762
/// default is insufficient, this method can be used by the target to override
@@ -3706,6 +3736,13 @@ class TargetLoweringBase {
37063736
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
37073737
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
37083738

3739+
using PartialReduceActionTypes =
3740+
std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
3741+
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
3742+
/// nodes, keep a LegalizeAction which indicates how instruction selection
3743+
/// should deal with this operation.
3744+
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
3745+
37093746
ValueTypeActionImpl ValueTypeActions;
37103747

37113748
private:

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def SDTSubVecInsert : SDTypeProfile<1, 3, [ // subvector insert
313313
SDTCisSubVecOfVec<2, 1>, SDTCisSameAs<0,1>, SDTCisInt<3>
314314
]>;
315315

316+
def SDTPartialReduceMLA : SDTypeProfile<1, 3, [ // partial reduce mla
317+
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>
318+
]>;
319+
316320
def SDTPrefetch : SDTypeProfile<0, 4, [ // prefetch
317321
SDTCisPtrTy<0>, SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisInt<1>
318322
]>;
@@ -513,6 +517,11 @@ def vecreduce_fmax : SDNode<"ISD::VECREDUCE_FMAX", SDTFPVecReduce>;
513517
def vecreduce_fminimum : SDNode<"ISD::VECREDUCE_FMINIMUM", SDTFPVecReduce>;
514518
def vecreduce_fmaximum : SDNode<"ISD::VECREDUCE_FMAXIMUM", SDTFPVecReduce>;
515519

520+
def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
521+
SDTPartialReduceMLA>;
522+
def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
523+
SDTPartialReduceMLA>;
524+
516525
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
517526
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
518527
def fmul : SDNode<"ISD::FMUL" , SDTFPBinOp, [SDNPCommutative]>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12644,8 +12644,13 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1264412644
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
1264512645
return SDValue();
1264612646

12647-
// FIXME: Add a check to only perform the DAG combine if there is lowering
12648-
// provided by the target
12647+
// Only perform the DAG combine if the target reports the PARTIAL_REDUCE_MLA
12648+
// node as Legal or Custom for the types the operands transform to.
12649+
auto *Context = DAG.getContext();
12650+
if (!TLI.isPartialReduceMLALegalOrCustom(
12651+
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12652+
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12653+
return SDValue();
1264912654

1265012655
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
1265112656

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
469469
case ISD::VECTOR_COMPRESS:
470470
case ISD::SCMP:
471471
case ISD::UCMP:
472-
case ISD::PARTIAL_REDUCE_UMLA:
473-
case ISD::PARTIAL_REDUCE_SMLA:
474472
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
475473
break;
476474
case ISD::SMULFIX:
@@ -530,6 +528,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
530528
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
531529
break;
532530
}
531+
case ISD::PARTIAL_REDUCE_UMLA:
532+
case ISD::PARTIAL_REDUCE_SMLA:
533+
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
534+
Node->getOperand(1).getValueType());
535+
break;
533536

534537
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
535538
case ISD::VPID: { \

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -843,10 +843,6 @@ void TargetLoweringBase::initActions() {
843843
setOperationAction(ISD::GET_FPENV, VT, Expand);
844844
setOperationAction(ISD::SET_FPENV, VT, Expand);
845845
setOperationAction(ISD::RESET_FPENV, VT, Expand);
846-
847-
// PartialReduceMLA operations default to expand.
848-
setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
849-
Expand);
850846
}
851847

852848
// Most targets ignore the @llvm.prefetch intrinsic.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18501850
setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
18511851
}
18521852

1853+
// Handle partial reduction operations
1854+
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
1855+
// Mark known legal pairs as 'Legal' (these are selected as UDOT or SDOT).
1856+
// Pairs without a registered action default to 'Expand'.
1857+
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
1858+
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
1859+
}
1860+
18531861
// Handle operations that are only available in non-streaming SVE mode.
18541862
if (Subtarget->isSVEAvailable()) {
18551863
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
@@ -1889,7 +1897,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18891897
}
18901898
}
18911899

1892-
18931900
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
18941901
// Only required for llvm.aarch64.mops.memset.tag
18951902
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,17 @@ let Predicates = [HasSVE_or_SME] in {
653653
defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
654654
defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
655655

656+
let Predicates = [HasSVE_or_SME] in {
657+
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
658+
(UDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
659+
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
660+
(SDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
661+
def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
662+
(UDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
663+
def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
664+
(SDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
665+
} // End HasSVE_or_SME
666+
656667
defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
657668
defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
658669

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 66 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
1212
;
1313
; CHECK-NODOT-LABEL: udot:
1414
; CHECK-NODOT: // %bb.0:
15-
; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
16-
; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
17-
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
18-
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
19-
; CHECK-NODOT-NEXT: umlal v0.4s, v4.4h, v3.4h
20-
; CHECK-NODOT-NEXT: umull v5.4s, v2.4h, v1.4h
21-
; CHECK-NODOT-NEXT: umlal2 v0.4s, v2.8h, v1.8h
22-
; CHECK-NODOT-NEXT: umlal2 v5.4s, v4.8h, v3.8h
23-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
15+
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
16+
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
17+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
18+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
19+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
20+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
21+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
2422
; CHECK-NODOT-NEXT: ret
2523
%u.wide = zext <16 x i8> %u to <16 x i32>
2624
%s.wide = zext <16 x i8> %s to <16 x i32>
@@ -52,20 +50,18 @@ define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
5250
; CHECK-NODOT-NEXT: mov x8, xzr
5351
; CHECK-NODOT-NEXT: .LBB1_1: // %vector.body
5452
; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
55-
; CHECK-NODOT-NEXT: ldr q0, [x1, x8]
56-
; CHECK-NODOT-NEXT: ldr q2, [x0, x8]
53+
; CHECK-NODOT-NEXT: ldr q0, [x0, x8]
54+
; CHECK-NODOT-NEXT: ldr q2, [x1, x8]
5755
; CHECK-NODOT-NEXT: add x8, x8, #16
5856
; CHECK-NODOT-NEXT: cmp x8, #16
59-
; CHECK-NODOT-NEXT: ushll2 v3.8h, v0.16b, #0
60-
; CHECK-NODOT-NEXT: ushll2 v4.8h, v2.16b, #0
61-
; CHECK-NODOT-NEXT: ushll v5.8h, v0.8b, #0
62-
; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
57+
; CHECK-NODOT-NEXT: umull v3.8h, v0.8b, v2.8b
58+
; CHECK-NODOT-NEXT: umull2 v2.8h, v0.16b, v2.16b
6359
; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
64-
; CHECK-NODOT-NEXT: umull v6.4s, v4.4h, v3.4h
65-
; CHECK-NODOT-NEXT: umlal v1.4s, v2.4h, v5.4h
66-
; CHECK-NODOT-NEXT: umlal2 v6.4s, v2.8h, v5.8h
67-
; CHECK-NODOT-NEXT: umlal2 v1.4s, v4.8h, v3.8h
68-
; CHECK-NODOT-NEXT: add v1.4s, v6.4s, v1.4s
60+
; CHECK-NODOT-NEXT: ushll v1.4s, v2.4h, #0
61+
; CHECK-NODOT-NEXT: uaddw v4.4s, v0.4s, v3.4h
62+
; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v3.8h
63+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v4.4s, v2.8h
64+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
6965
; CHECK-NODOT-NEXT: b.ne .LBB1_1
7066
; CHECK-NODOT-NEXT: // %bb.2: // %end
7167
; CHECK-NODOT-NEXT: ret
@@ -99,19 +95,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
9995
;
10096
; CHECK-NODOT-LABEL: udot_narrow:
10197
; CHECK-NODOT: // %bb.0:
102-
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
103-
; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
98+
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
10499
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
105-
; CHECK-NODOT-NEXT: umull v3.4s, v2.4h, v1.4h
106-
; CHECK-NODOT-NEXT: umull2 v4.4s, v2.8h, v1.8h
107-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
108-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
109-
; CHECK-NODOT-NEXT: umlal v0.4s, v2.4h, v1.4h
100+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
101+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
102+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
103+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
110104
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
111-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
112-
; CHECK-NODOT-NEXT: umlal v3.4s, v6.4h, v5.4h
113-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
105+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
114106
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
107+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
108+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
115109
; CHECK-NODOT-NEXT: ret
116110
%u.wide = zext <8 x i8> %u to <8 x i32>
117111
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -128,15 +122,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
128122
;
129123
; CHECK-NODOT-LABEL: sdot:
130124
; CHECK-NODOT: // %bb.0:
131-
; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
132-
; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
133-
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
134-
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
135-
; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
136-
; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
137-
; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
138-
; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
139-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
125+
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
126+
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
127+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
128+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
129+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
130+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
131+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
140132
; CHECK-NODOT-NEXT: ret
141133
%u.wide = sext <16 x i8> %u to <16 x i32>
142134
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -153,19 +145,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
153145
;
154146
; CHECK-NODOT-LABEL: sdot_narrow:
155147
; CHECK-NODOT: // %bb.0:
156-
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
157-
; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
148+
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
158149
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
159-
; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
160-
; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
161-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
162-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
163-
; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
150+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
151+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
152+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
153+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
164154
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
165-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
166-
; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
167-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
155+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
168156
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
157+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
158+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
169159
; CHECK-NODOT-NEXT: ret
170160
%u.wide = sext <8 x i8> %u to <8 x i32>
171161
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -417,27 +407,19 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
417407
;
418408
; CHECK-NODOT-LABEL: udot_8to64:
419409
; CHECK-NODOT: // %bb.0: // %entry
420-
; CHECK-NODOT-NEXT: ushll v4.8h, v3.8b, #0
421-
; CHECK-NODOT-NEXT: ushll v5.8h, v2.8b, #0
422-
; CHECK-NODOT-NEXT: ushll2 v3.8h, v3.16b, #0
423-
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
424-
; CHECK-NODOT-NEXT: ushll v6.4s, v4.4h, #0
425-
; CHECK-NODOT-NEXT: ushll v7.4s, v5.4h, #0
410+
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
411+
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
412+
; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
413+
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
426414
; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
427-
; CHECK-NODOT-NEXT: ushll2 v5.4s, v5.8h, #0
428-
; CHECK-NODOT-NEXT: ushll2 v16.4s, v3.8h, #0
429-
; CHECK-NODOT-NEXT: ushll2 v17.4s, v2.8h, #0
430-
; CHECK-NODOT-NEXT: ushll v3.4s, v3.4h, #0
431-
; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #0
432-
; CHECK-NODOT-NEXT: umlal2 v1.2d, v7.4s, v6.4s
433-
; CHECK-NODOT-NEXT: umlal v0.2d, v7.2s, v6.2s
434-
; CHECK-NODOT-NEXT: umull2 v18.2d, v5.4s, v4.4s
435-
; CHECK-NODOT-NEXT: umull v4.2d, v5.2s, v4.2s
436-
; CHECK-NODOT-NEXT: umlal2 v1.2d, v17.4s, v16.4s
437-
; CHECK-NODOT-NEXT: umlal v0.2d, v17.2s, v16.2s
438-
; CHECK-NODOT-NEXT: umlal2 v18.2d, v2.4s, v3.4s
439-
; CHECK-NODOT-NEXT: umlal v4.2d, v2.2s, v3.2s
440-
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
415+
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
416+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
417+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
418+
; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
419+
; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
420+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
421+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
422+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
441423
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
442424
; CHECK-NODOT-NEXT: ret
443425
entry:
@@ -460,27 +442,19 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
460442
;
461443
; CHECK-NODOT-LABEL: sdot_8to64:
462444
; CHECK-NODOT: // %bb.0: // %entry
463-
; CHECK-NODOT-NEXT: sshll v4.8h, v3.8b, #0
464-
; CHECK-NODOT-NEXT: sshll v5.8h, v2.8b, #0
465-
; CHECK-NODOT-NEXT: sshll2 v3.8h, v3.16b, #0
466-
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
467-
; CHECK-NODOT-NEXT: sshll v6.4s, v4.4h, #0
468-
; CHECK-NODOT-NEXT: sshll v7.4s, v5.4h, #0
445+
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
446+
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
447+
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
448+
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
469449
; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
470-
; CHECK-NODOT-NEXT: sshll2 v5.4s, v5.8h, #0
471-
; CHECK-NODOT-NEXT: sshll2 v16.4s, v3.8h, #0
472-
; CHECK-NODOT-NEXT: sshll2 v17.4s, v2.8h, #0
473-
; CHECK-NODOT-NEXT: sshll v3.4s, v3.4h, #0
474-
; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #0
475-
; CHECK-NODOT-NEXT: smlal2 v1.2d, v7.4s, v6.4s
476-
; CHECK-NODOT-NEXT: smlal v0.2d, v7.2s, v6.2s
477-
; CHECK-NODOT-NEXT: smull2 v18.2d, v5.4s, v4.4s
478-
; CHECK-NODOT-NEXT: smull v4.2d, v5.2s, v4.2s
479-
; CHECK-NODOT-NEXT: smlal2 v1.2d, v17.4s, v16.4s
480-
; CHECK-NODOT-NEXT: smlal v0.2d, v17.2s, v16.2s
481-
; CHECK-NODOT-NEXT: smlal2 v18.2d, v2.4s, v3.4s
482-
; CHECK-NODOT-NEXT: smlal v4.2d, v2.2s, v3.2s
483-
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
450+
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
451+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
452+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
453+
; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
454+
; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
455+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
456+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
457+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
484458
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
485459
; CHECK-NODOT-NEXT: ret
486460
entry:
@@ -797,10 +771,9 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
797771
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
798772
; CHECK-LABEL: not_udot:
799773
; CHECK: // %bb.0:
800-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
801-
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
802-
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
803-
; CHECK-NEXT: umlal2 v0.4s, v2.8h, v1.8h
774+
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
775+
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
776+
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
804777
; CHECK-NEXT: ret
805778
%u.wide = zext <8 x i8> %u to <8 x i32>
806779
%s.wide = zext <8 x i8> %s to <8 x i32>

0 commit comments

Comments
 (0)