|
29 | 29 | #include "llvm/IR/DiagnosticInfo.h"
|
30 | 30 | #include "llvm/IR/DiagnosticPrinter.h"
|
31 | 31 | #include "llvm/IR/Function.h"
|
| 32 | +#include "llvm/IR/IntrinsicInst.h" |
32 | 33 | #include "llvm/IR/Intrinsics.h"
|
33 | 34 | #include "llvm/IR/IntrinsicsWebAssembly.h"
|
34 | 35 | #include "llvm/Support/ErrorHandling.h"
|
@@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
|
177 | 178 |
|
178 | 179 | // SIMD-specific configuration
|
179 | 180 | if (Subtarget->hasSIMD128()) {
|
| 181 | + |
| 182 | + // Combine partial.reduce.add before legalization gets confused. |
| 183 | + setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); |
| 184 | + |
180 | 185 | // Combine vector mask reductions into alltrue/anytrue
|
181 | 186 | setTargetDAGCombine(ISD::SETCC);
|
182 | 187 |
|
@@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
|
406 | 411 | return TargetLowering::getPointerMemTy(DL, AS);
|
407 | 412 | }
|
408 | 413 |
|
| 414 | +bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic( |
| 415 | + const IntrinsicInst *I) const { |
| 416 | + if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add) |
| 417 | + return true; |
| 418 | + |
| 419 | + EVT VT = EVT::getEVT(I->getType()); |
| 420 | + auto Op1 = I->getOperand(1); |
| 421 | + |
| 422 | + if (auto *InputInst = dyn_cast<Instruction>(Op1)) { |
| 423 | + if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL) |
| 424 | + return true; |
| 425 | + |
| 426 | + if (isa<Instruction>(InputInst->getOperand(0)) && |
| 427 | + isa<Instruction>(InputInst->getOperand(1))) { |
| 428 | + // dot only supports signed inputs but also support lowering unsigned. |
| 429 | + if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() != |
| 430 | + cast<Instruction>(InputInst->getOperand(1))->getOpcode()) |
| 431 | + return true; |
| 432 | + |
| 433 | + EVT Op1VT = EVT::getEVT(Op1->getType()); |
| 434 | + if (Op1VT.getVectorElementType() == VT.getVectorElementType() && |
| 435 | + ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) || |
| 436 | + (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount()))) |
| 437 | + return false; |
| 438 | + } |
| 439 | + } |
| 440 | + return true; |
| 441 | +} |
| 442 | + |
409 | 443 | TargetLowering::AtomicExpansionKind
|
410 | 444 | WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
|
411 | 445 | // We have wasm instructions for these
|
@@ -2030,6 +2064,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
|
2030 | 2064 | MachinePointerInfo(SV));
|
2031 | 2065 | }
|
2032 | 2066 |
|
| 2067 | +// Try to lower partial.reduce.add to a dot or fallback to a sequence with |
| 2068 | +// extmul and adds. |
| 2069 | +SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { |
| 2070 | + assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN); |
| 2071 | + if (N->getConstantOperandVal(0) != |
| 2072 | + Intrinsic::experimental_vector_partial_reduce_add) |
| 2073 | + return SDValue(); |
| 2074 | + |
| 2075 | + assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32"); |
| 2076 | + SDLoc DL(N); |
| 2077 | + SDValue Mul = N->getOperand(2); |
| 2078 | + assert(Mul->getOpcode() == ISD::MUL && "expected mul input"); |
| 2079 | + |
| 2080 | + SDValue ExtendLHS = Mul->getOperand(0); |
| 2081 | + SDValue ExtendRHS = Mul->getOperand(1); |
| 2082 | + assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) && |
| 2083 | + ISD::isExtOpcode(ExtendRHS.getOpcode())) && |
| 2084 | + "expected widening mul"); |
| 2085 | + assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() && |
| 2086 | + "expected mul to use the same extend for both operands"); |
| 2087 | + |
| 2088 | + SDValue ExtendInLHS = ExtendLHS->getOperand(0); |
| 2089 | + SDValue ExtendInRHS = ExtendRHS->getOperand(0); |
| 2090 | + bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND; |
| 2091 | + |
| 2092 | + if (ExtendInLHS->getValueType(0) == MVT::v8i16) { |
| 2093 | + if (IsSigned) { |
| 2094 | + // i32x4.dot_i16x8_s |
| 2095 | + SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, |
| 2096 | + ExtendInLHS, ExtendInRHS); |
| 2097 | + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot); |
| 2098 | + } |
| 2099 | + |
| 2100 | + unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U; |
| 2101 | + unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U; |
| 2102 | + |
| 2103 | + // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs))) |
| 2104 | + SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS); |
| 2105 | + SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS); |
| 2106 | + SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS); |
| 2107 | + SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS); |
| 2108 | + |
| 2109 | + SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS); |
| 2110 | + SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS); |
| 2111 | + SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh); |
| 2112 | + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); |
| 2113 | + } else { |
| 2114 | + assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && |
| 2115 | + "expected v16i8 input types"); |
| 2116 | + // Lower to a wider tree, using twice the operations compared to above. |
| 2117 | + if (IsSigned) { |
| 2118 | + // Use two dots |
| 2119 | + unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S; |
| 2120 | + unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S; |
| 2121 | + SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS); |
| 2122 | + SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS); |
| 2123 | + SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS); |
| 2124 | + SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS); |
| 2125 | + SDValue DotLHS = |
| 2126 | + DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS); |
| 2127 | + SDValue DotRHS = |
| 2128 | + DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS); |
| 2129 | + SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS); |
| 2130 | + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); |
| 2131 | + } |
| 2132 | + |
| 2133 | + unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U; |
| 2134 | + unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U; |
| 2135 | + SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS); |
| 2136 | + SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS); |
| 2137 | + SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS); |
| 2138 | + SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS); |
| 2139 | + |
| 2140 | + SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); |
| 2141 | + SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS); |
| 2142 | + |
| 2143 | + SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow); |
| 2144 | + SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh); |
| 2145 | + SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow); |
| 2146 | + SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh); |
| 2147 | + |
| 2148 | + SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow); |
| 2149 | + SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh); |
| 2150 | + SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); |
| 2151 | + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); |
| 2152 | + } |
| 2153 | +} |
| 2154 | + |
2033 | 2155 | SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
|
2034 | 2156 | SelectionDAG &DAG) const {
|
2035 | 2157 | MachineFunction &MF = DAG.getMachineFunction();
|
@@ -3126,5 +3248,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
|
3126 | 3248 | return performVectorTruncZeroCombine(N, DCI);
|
3127 | 3249 | case ISD::TRUNCATE:
|
3128 | 3250 | return performTruncateCombine(N, DCI);
|
| 3251 | + case ISD::INTRINSIC_WO_CHAIN: |
| 3252 | + return performLowerPartialReduction(N, DCI.DAG); |
3129 | 3253 | }
|
3130 | 3254 | }
|
0 commit comments