Skip to content

[WebAssembly] Autovec support for dot #123207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ HANDLE_NODETYPE(Wrapper)
HANDLE_NODETYPE(WrapperREL)
HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
Expand Down
124 changes: 124 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(

// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {

// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);

Expand Down Expand Up @@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}

bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;

EVT VT = EVT::getEVT(I->getType());
auto Op1 = I->getOperand(1);

if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
return true;

if (isa<Instruction>(InputInst->getOperand(0)) &&
isa<Instruction>(InputInst->getOperand(1))) {
// dot only supports signed inputs but also support lowering unsigned.
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
return true;

EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
return false;
}
}
return true;
}

TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
Expand Down Expand Up @@ -2030,6 +2064,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}

// Lower a partial.reduce.add intrinsic node whose input is a widening multiply.
// Signed inputs are mapped onto i32x4.dot_i16x8_s; unsigned inputs use
// explicit extend/mul/add sequences. The result is always added onto the
// accumulator operand. Returns an empty SDValue for any other intrinsic.
SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
  const bool IsPartialReduceAdd =
      N->getConstantOperandVal(0) ==
      Intrinsic::experimental_vector_partial_reduce_add;
  if (!IsPartialReduceAdd)
    return SDValue();

  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
  SDLoc DL(N);
  SDValue Mul = N->getOperand(2);
  assert(Mul->getOpcode() == ISD::MUL && "expected mul input");

  SDValue ExtendLHS = Mul->getOperand(0);
  SDValue ExtendRHS = Mul->getOperand(1);
  assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
          ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
         "expected widening mul");
  assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
         "expected mul to use the same extend for both operands");

  SDValue ExtendInLHS = ExtendLHS->getOperand(0);
  SDValue ExtendInRHS = ExtendRHS->getOperand(0);
  const bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;

  // Half-vector extend opcodes matching the signedness of the original
  // extends.
  const unsigned LowOpc =
      IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
  const unsigned HighOpc =
      IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;

  // Builds a two-operand v4i32 node.
  auto Node32 = [&](unsigned Opc, SDValue A, SDValue B) {
    return DAG.getNode(Opc, DL, MVT::v4i32, A, B);
  };
  // Adds a partial sum onto the accumulator operand of the intrinsic.
  auto Accumulate = [&](SDValue Partial) {
    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Partial);
  };
  // Extends the low and high halves of both multiply inputs to HalfVT.
  struct Halves {
    SDValue LowLHS, LowRHS, HighLHS, HighRHS;
  };
  auto ExtendHalves = [&](EVT HalfVT) {
    Halves H;
    H.LowLHS = DAG.getNode(LowOpc, DL, HalfVT, ExtendInLHS);
    H.LowRHS = DAG.getNode(LowOpc, DL, HalfVT, ExtendInRHS);
    H.HighLHS = DAG.getNode(HighOpc, DL, HalfVT, ExtendInLHS);
    H.HighRHS = DAG.getNode(HighOpc, DL, HalfVT, ExtendInRHS);
    return H;
  };

  if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
    // Signed i16 inputs are exactly one i32x4.dot_i16x8_s.
    if (IsSigned) {
      SDValue Dot = Node32(WebAssemblyISD::DOT, ExtendInLHS, ExtendInRHS);
      return Accumulate(Dot);
    }

    // Unsigned i16 inputs: widen each half to i32, multiply, and sum the two
    // products.
    const Halves H = ExtendHalves(MVT::v4i32);
    SDValue ProdLow = Node32(ISD::MUL, H.LowLHS, H.LowRHS);
    SDValue ProdHigh = Node32(ISD::MUL, H.HighLHS, H.HighRHS);
    SDValue Sum = Node32(ISD::ADD, ProdLow, ProdHigh);
    return Accumulate(Sum);
  }

  assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
         "expected v16i8 input types");

  // i8 inputs need one more widening level, so the tree is twice as large as
  // for i16 inputs.
  const Halves H = ExtendHalves(MVT::v8i16);

  if (IsSigned) {
    // Each pair of i16 halves feeds its own dot.
    SDValue DotLow = Node32(WebAssemblyISD::DOT, H.LowLHS, H.LowRHS);
    SDValue DotHigh = Node32(WebAssemblyISD::DOT, H.HighLHS, H.HighRHS);
    SDValue Sum = Node32(ISD::ADD, DotLow, DotHigh);
    return Accumulate(Sum);
  }

  // Unsigned i8 inputs: multiply in i16, then widen every half of both
  // products to i32 and sum them.
  SDValue ProdLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, H.LowLHS, H.LowRHS);
  SDValue ProdHigh =
      DAG.getNode(ISD::MUL, DL, MVT::v8i16, H.HighLHS, H.HighRHS);

  SDValue LowOfLow = DAG.getNode(LowOpc, DL, MVT::v4i32, ProdLow);
  SDValue LowOfHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, ProdHigh);
  SDValue HighOfLow = DAG.getNode(HighOpc, DL, MVT::v4i32, ProdLow);
  SDValue HighOfHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, ProdHigh);

  SDValue SumLow = Node32(ISD::ADD, LowOfLow, HighOfLow);
  SDValue SumHigh = Node32(ISD::ADD, LowOfHigh, HighOfHigh);
  SDValue Sum = Node32(ISD::ADD, SumLow, SumHigh);
  return Accumulate(Sum);
}

SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
Expand Down Expand Up @@ -3126,5 +3248,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
}
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;

/// Returns true when the partial-reduction intrinsic \p I should be expanded
/// generically. May return false only for
/// experimental.vector.partial.reduce.add whose input is a multiply of two
/// instruction operands with matching opcodes and two or four times as many
/// lanes as the result.
bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,15 @@ def : Pat<(wasm_shr_u
}

// Widening dot product: i32x4.dot_i16x8_s
// Type profile of the DOT node: one v4i32 result and two v8i16 operands.
def dot_t : SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisVT<1, v8i16>, SDTCisVT<2, v8i16>]>;
// SelectionDAG node for WebAssemblyISD::DOT, constrained by dot_t.
def wasm_dot : SDNode<"WebAssemblyISD::DOT", dot_t>;
// The instruction definition matches the int_wasm_dot intrinsic directly and
// is marked commutable.
let isCommutable = 1 in
defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
[(set V128:$dst, (int_wasm_dot V128:$lhs, V128:$rhs))],
"i32x4.dot_i16x8_s\t$dst, $lhs, $rhs", "i32x4.dot_i16x8_s",
186>;
// Additionally select the wasm_dot node to the same DOT instruction.
def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
(DOT $lhs, $rhs)>;

// Extending multiplication: extmul_{low,high}_P, extmul_high
def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
Expand Down
47 changes: 47 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,53 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return Cost;
}

/// Cost of a partial reduction for the vectorizer. Only an add-of-mul
/// reduction into an i32 accumulator is costed, and only when both inputs have
/// the same type and the same extend kind. Every other request, and any
/// request without SIMD128 or with a scalable VF, gets an invalid cost.
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp) const {
  const InstructionCost Unsupported = InstructionCost::getInvalid();
  if (!VF.isFixed() || !ST->hasSIMD128())
    return Unsupported;

  // Instructions that could implement a partial reduction:
  //  - i16x8.extadd_pairwise_i8x16_sx
  //  - i32x4.extadd_pairwise_i16x8_sx
  //  - i32x4.dot_i16x8_s
  // For now only the dot form is modelled, i.e. add(acc, mul(ext a, ext b)).
  const bool IsAddOfMul =
      Opcode == Instruction::Add && BinOp && *BinOp == Instruction::Mul;
  if (!IsAddOfMul)
    return Unsupported;

  const bool InputsAgree = InputTypeA == InputTypeB && OpAExtend == OpBExtend;
  if (!InputsAgree)
    return Unsupported;

  EVT InputEVT = EVT::getEVT(InputTypeA);
  EVT AccumEVT = EVT::getEVT(AccumType);

  // TODO: Add i64 accumulator.
  if (AccumEVT != MVT::i32)
    return Unsupported;

  InstructionCost Unit(TTI::TCC_Basic);
  const bool IsSigned = OpAExtend == TTI::PR_SignExtend;
  const auto Lanes = VF.getFixedValue();

  // Signed i16x8 inputs are a single dot; unsigned costs twice as much.
  if (InputEVT == MVT::i16 && Lanes == 8)
    return IsSigned ? Unit : Unit * 2;

  // i8x16 inputs need a sequence twice as long as the i16x8 case.
  if (InputEVT == MVT::i8 && Lanes == 16)
    return IsSigned ? Unit * 2 : Unit * 4;

  return Unsupported;
}

TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
const IntrinsicInst *II) const {

Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0, Value *Op1);

/// Cost of a partial reduction with the given opcode, input and accumulator
/// types, vectorization factor, operand extend kinds and optional binary op
/// feeding the reduction. The implementation returns an invalid cost for any
/// combination it does not handle.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const;
TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;

Expand Down
Loading