[DAGCombine] Simplify partial_reduce_*mla with constant. #138289

Merged: 2 commits, May 6, 2025
54 changes: 34 additions & 20 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}

// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);

auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

APInt ConstantOne;
APInt C;
if (Op1->getOpcode() != ISD::MUL ||
!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
!ConstantOne.isOne())
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
return SDValue();

SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
unsigned LHSOpcode = LHS->getOpcode();
unsigned RHSOpcode = RHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();

SDValue LHSExtOp = LHS->getOperand(0);
SDValue RHSExtOp = RHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
return SDValue();

// Only perform the DAG combine if there is custom lowering provided by the
// target
auto *Context = DAG.getContext();
// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}

unsigned RHSOpcode = RHS->getOpcode();
if (!ISD::isExtOpcode(RHSOpcode))
return SDValue();

SDValue RHSExtOp = RHS->getOperand(0);
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
return SDValue();

// For a 2-stage extend the signedness of both of the extends must be the
// same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
}
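
The key new condition is the one guarding the splat-constant fold above: the constant C only folds into the narrow operand if truncating it to the pre-extend element type and re-extending it with the same signedness as the LHS extend gives back the original value. The following is a minimal standalone sketch of that round-trip test, not the actual LLVM code: it models the APInt trunc/zext/sext round trip with plain 64-bit integers, and constantFitsNarrowType is a hypothetical helper used only for illustration. The values mirror the AArch64 tests below (255 and -1 fit an i8 operand, 256 does not).

// Standalone sketch of the constant-fit check, assuming fixed-width
// integers instead of APInt; illustration only.
#include <cassert>
#include <cstdint>

// Returns true if the splat constant C (as seen in the wide mul) survives
// truncation to NarrowBits followed by re-extension with the given
// signedness, i.e. the fold to partial_reduce_*mla(acc, x, C_trunc) is valid.
bool constantFitsNarrowType(int64_t C, unsigned NarrowBits, bool IsSigned) {
  uint64_t Mask = (NarrowBits >= 64) ? ~0ULL : ((1ULL << NarrowBits) - 1);
  uint64_t Trunc = static_cast<uint64_t>(C) & Mask; // models APInt::trunc
  int64_t Reext;
  if (IsSigned) {
    // Models APInt::sext: replicate the sign bit of the narrow value.
    uint64_t SignBit = 1ULL << (NarrowBits - 1);
    Reext = static_cast<int64_t>((Trunc ^ SignBit) - SignBit);
  } else {
    // Models APInt::zext: high bits become zero.
    Reext = static_cast<int64_t>(Trunc);
  }
  return Reext == C;
}

int main() {
  // i8 operands extended to i32, as in the tests below.
  assert(constantFitsNarrowType(255, 8, /*IsSigned=*/false));  // udot_imm folds
  assert(!constantFitsNarrowType(256, 8, /*IsSigned=*/false)); // udot_imm_does_not_fit
  assert(constantFitsNarrowType(-1, 8, /*IsSigned=*/true));    // sdot_imm folds
  assert(!constantFitsNarrowType(256, 8, /*IsSigned=*/true));  // sdot_imm_does_not_fit
  return 0;
}

In the real combine this check is performed on APInt values inside DAGCombiner::visitPARTIAL_REDUCE_MLA, and is additionally gated on the target reporting the PARTIAL_REDUCE_*MLA node as legal or custom for the transformed types.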
143 changes: 142 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
ret <vscale x 2 x i16> %partial.reduce
}


define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: partial_reduce_only_split_acc:
; CHECK: // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
<vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 4 x i64> %partial.reduce
}

define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
; CHECK-LABEL: sdot_imm:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sunpklo z2.h, z1.b
; CHECK-NEXT: sunpkhi z1.h, z1.b
; CHECK-NEXT: sunpklo z3.s, z2.h
; CHECK-NEXT: sunpkhi z2.s, z2.h
; CHECK-NEXT: sub z0.s, z0.s, z3.s
; CHECK-NEXT: sunpklo z3.s, z1.h
; CHECK-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEXT: sub z0.s, z0.s, z2.s
; CHECK-NEXT: sub z0.s, z0.s, z3.s
; CHECK-NEXT: sub z0.s, z0.s, z1.s
; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_imm:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
; CHECK-LABEL: sdot_imm_does_not_fit:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sunpklo z2.h, z1.b
; CHECK-NEXT: sunpkhi z1.h, z1.b
; CHECK-NEXT: sunpklo z3.s, z2.h
; CHECK-NEXT: sunpkhi z2.s, z2.h
; CHECK-NEXT: sunpklo z4.s, z1.h
; CHECK-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEXT: lsl z4.s, z4.s, #8
; CHECK-NEXT: lsl z2.s, z2.s, #8
; CHECK-NEXT: lsl z3.s, z3.s, #8
; CHECK-NEXT: lsl z1.s, z1.s, #8
; CHECK-NEXT: add z0.s, z0.s, z3.s
; CHECK-NEXT: add z2.s, z2.s, z4.s
; CHECK-NEXT: add z0.s, z0.s, z2.s
; CHECK-NEXT: add z0.s, z0.s, z1.s
; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z1.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
; CHECK-LABEL: udot_imm:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uunpklo z3.h, z1.b
; CHECK-NEXT: mov z2.s, #255 // =0xff
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEXT: uunpklo z4.s, z3.h
; CHECK-NEXT: uunpkhi z3.s, z3.h
; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
; CHECK-NEXT: uunpklo z4.s, z1.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: mla z0.s, p0/m, z3.s, z2.s
; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_imm:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
; CHECK-LABEL: udot_imm_does_not_fit:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uunpklo z2.h, z1.b
; CHECK-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEXT: uunpklo z3.s, z2.h
; CHECK-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEXT: uunpklo z4.s, z1.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: lsl z4.s, z4.s, #8
; CHECK-NEXT: lsl z2.s, z2.s, #8
; CHECK-NEXT: lsl z3.s, z3.s, #8
; CHECK-NEXT: lsl z1.s, z1.s, #8
; CHECK-NEXT: add z0.s, z0.s, z3.s
; CHECK-NEXT: add z2.s, z2.s, z4.s
; CHECK-NEXT: add z0.s, z0.s, z2.s
; CHECK-NEXT: add z0.s, z0.s, z1.s
; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z1.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}