[NVPTX] Optimize v2x16 BUILD_VECTORs to PRMT #116675
@llvm/pr-subscribers-backend-nvptx

Author: Fraser Cormack (frasercrmck)

Changes: When two 16-bit values are combined into a v2x16 vector, and those values come from truncations of 32-bit values, a PRMT instruction can save registers by selecting bytes directly from the original 32-bit values. We do this during a post-legalize DAG combine, as these opportunities are typically only exposed after the BUILD_VECTOR's operands have been legalized.

Additionally, if the 32-bit values are right-shifted, we can fold in the shift by selecting higher bytes with PRMT. Only logical right-shifts by 16 are supported (for now) since those are the only situations seen in practice. Right shifts by 16 often come up during the legalization of EXTRACT_VECTOR_ELT.

This idea was brought up in a PR comment by @Artem-B.

Patch is 23.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116675.diff

6 Files Affected:
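For readers unfamiliar with prmt.b32, whose selectors (0x5410, 0x7632) appear in the combine below and in the updated test CHECK lines, here is a minimal Python sketch of the instruction's default byte-selection mode, based on the PTX ISA description. It models only the plain byte indices 0-7 that this combine emits, not the sign-replication selector bits:

```python
def prmt(a: int, b: int, selector: int) -> int:
    """Model of PTX prmt.b32 (default mode): selector nibble i picks
    result byte i from the 8-byte pool {a: bytes 0-3, b: bytes 4-7}.
    Only plain indices 0-7 are modeled; bit 3 of a real selector
    nibble would replicate the chosen byte's sign bit instead."""
    pool = a.to_bytes(4, "little") + b.to_bytes(4, "little")
    out = bytes(pool[(selector >> (4 * i)) & 0x7] for i in range(4))
    return int.from_bytes(out, "little")

a, b = 0xAABBCCDD, 0x11223344

# Base selector from the combine: (0x54 << 8) | 0x10 == 0x5410 takes
# the low two bytes of each operand, i.e. {lo16(a), lo16(b)}.
assert prmt(a, b, 0x5410) == 0x3344CCDD

# Folding a logical shift-right by 16 adds 0x22 to an operand's byte
# pair, selecting its high two bytes instead: 0x5410 + 0x2222 == 0x7632,
# the selector seen in the updated tests.
assert prmt(a, b, 0x7632) == 0x1122AABB
```

Under this model, one prmt replaces the mov.b32-unpack/mov.b32-repack pairs that the deleted CHECK lines below used to match.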
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 4ad0200ca5cf83..da38aff1efefd2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -762,7 +762,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
- ISD::VSELECT});
+ ISD::VSELECT, ISD::BUILD_VECTOR});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6176,6 +6176,57 @@ static SDValue PerformLOADCombine(SDNode *N,
DL);
}
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ auto VT = N->getValueType(0);
+ if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
+ return SDValue();
+
+ auto Op0 = N->getOperand(0);
+ auto Op1 = N->getOperand(1);
+
+ // Start out by assuming we want to take the lower 2 bytes of each i32
+ // operand.
+ uint64_t Op0Bytes = 0x10;
+ uint64_t Op1Bytes = 0x54;
+
+ std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
+ {&Op1, &Op1Bytes}};
+
+ // Check that each operand is an i16, truncated from an i32 operand. We'll
+ // select individual bytes from those original operands. Optionally, fold in a
+ // shift right of that original operand.
+ for (auto &[Op, OpBytes] : OpData) {
+ // Eat up any bitcast
+ if (Op->getOpcode() == ISD::BITCAST)
+ *Op = Op->getOperand(0);
+
+ if (Op->getValueType() != MVT::i16 || Op->getOpcode() != ISD::TRUNCATE ||
+ Op->getOperand(0).getValueType() != MVT::i32)
+ return SDValue();
+
+ *Op = Op->getOperand(0);
+
+ // Optionally, fold in a shift-right of the original operand and permute
+ // the two higher bytes from the shifted operand
+ if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
+ if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
+ *OpBytes += 0x22;
+ *Op = Op->getOperand(0);
+ }
+ }
+ }
+
+ SDLoc DL(N);
+ auto &DAG = DCI.DAG;
+
+ auto PRMT = DAG.getNode(
+ NVPTXISD::PRMT, DL, MVT::v4i8,
+ {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
+ DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
+ return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -6210,6 +6261,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
+ case ISD::BUILD_VECTOR:
+ return PerformBUILD_VECTORCombine(N, DCI);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 80815b3ca37c05..bd45d85d393000 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -159,8 +159,8 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-LABEL: test_faddx2(
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
-; SM70-NEXT: .reg .b16 %rs<13>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b16 %rs<9>;
+; SM70-NEXT: .reg .b32 %r<25>;
; SM70-NEXT: .reg .f32 %f<7>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
@@ -182,7 +182,6 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p1, %f3, %f3;
; SM70-NEXT: or.b32 %r11, %r7, 4194304;
; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: mov.b32 %f4, %r14;
@@ -197,8 +196,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p2, %f6, %f6;
; SM70-NEXT: or.b32 %r21, %r17, 4194304;
; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT: mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
; SM70-NEXT: st.param.b32 [func_retval0], %r23;
; SM70-NEXT: ret;
;
@@ -266,8 +264,8 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-LABEL: test_fsubx2(
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
-; SM70-NEXT: .reg .b16 %rs<13>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b16 %rs<9>;
+; SM70-NEXT: .reg .b32 %r<25>;
; SM70-NEXT: .reg .f32 %f<7>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
@@ -289,7 +287,6 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p1, %f3, %f3;
; SM70-NEXT: or.b32 %r11, %r7, 4194304;
; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: mov.b32 %f4, %r14;
@@ -304,8 +301,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p2, %f6, %f6;
; SM70-NEXT: or.b32 %r21, %r17, 4194304;
; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT: mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
; SM70-NEXT: st.param.b32 [func_retval0], %r23;
; SM70-NEXT: ret;
;
@@ -373,8 +369,8 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-LABEL: test_fmulx2(
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
-; SM70-NEXT: .reg .b16 %rs<13>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b16 %rs<9>;
+; SM70-NEXT: .reg .b32 %r<25>;
; SM70-NEXT: .reg .f32 %f<7>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
@@ -396,7 +392,6 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p1, %f3, %f3;
; SM70-NEXT: or.b32 %r11, %r7, 4194304;
; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: mov.b32 %f4, %r14;
@@ -411,8 +406,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p2, %f6, %f6;
; SM70-NEXT: or.b32 %r21, %r17, 4194304;
; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT: mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
; SM70-NEXT: st.param.b32 [func_retval0], %r23;
; SM70-NEXT: ret;
;
@@ -480,8 +474,8 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-LABEL: test_fdiv(
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
-; SM70-NEXT: .reg .b16 %rs<13>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b16 %rs<9>;
+; SM70-NEXT: .reg .b32 %r<25>;
; SM70-NEXT: .reg .f32 %f<7>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
@@ -503,7 +497,6 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p1, %f3, %f3;
; SM70-NEXT: or.b32 %r11, %r7, 4194304;
; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: mov.b32 %f4, %r14;
@@ -518,8 +511,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70-NEXT: setp.nan.f32 %p2, %f6, %f6;
; SM70-NEXT: or.b32 %r21, %r17, 4194304;
; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT: mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
; SM70-NEXT: st.param.b32 [func_retval0], %r23;
; SM70-NEXT: ret;
;
@@ -1724,8 +1716,8 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
; SM70-LABEL: test_maxnum_v2(
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
-; SM70-NEXT: .reg .b16 %rs<13>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b16 %rs<9>;
+; SM70-NEXT: .reg .b32 %r<25>;
; SM70-NEXT: .reg .f32 %f<7>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
@@ -1747,7 +1739,6 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
; SM70-NEXT: setp.nan.f32 %p1, %f3, %f3;
; SM70-NEXT: or.b32 %r11, %r7, 4194304;
; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: mov.b32 %f4, %r14;
@@ -1762,8 +1753,7 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
; SM70-NEXT: setp.nan.f32 %p2, %f6, %f6;
; SM70-NEXT: or.b32 %r21, %r17, 4194304;
; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT: mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
; SM70-NEXT: st.param.b32 [func_retval0], %r23;
; SM70-NEXT: ret;
;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
index 8cc4548f6e85e0..f22d2e7a6d3c9d 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
@@ -1050,8 +1050,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<9>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<25>;
-; CHECK-SM70-NEXT: .reg .b32 %r<61>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<21>;
+; CHECK-SM70-NEXT: .reg .b32 %r<62>;
; CHECK-SM70-NEXT: .reg .f32 %f<19>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -1134,7 +1134,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-NEXT: setp.nan.f32 %p7, %f15, %f15;
; CHECK-SM70-NEXT: or.b32 %r49, %r45, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r50, %r49, %r48, %p7;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs20}, %r50; }
; CHECK-SM70-NEXT: cvt.u32.u16 %r51, %rs17;
; CHECK-SM70-NEXT: shl.b32 %r52, %r51, 16;
; CHECK-SM70-NEXT: mov.b32 %f16, %r52;
@@ -1148,8 +1147,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-NEXT: setp.nan.f32 %p8, %f18, %f18;
; CHECK-SM70-NEXT: or.b32 %r58, %r54, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r59, %r58, %r57, %p8;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs23}, %r59; }
-; CHECK-SM70-NEXT: mov.b32 %r60, {%rs23, %rs20};
+; CHECK-SM70-NEXT: prmt.b32 %r60, %r59, %r50, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r60;
; CHECK-SM70-NEXT: ret;
%1 = fmul <2 x bfloat> %a, %b
@@ -1189,8 +1187,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-LABEL: fma_bf16x2_expanded_maxnum_no_nans(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<5>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT: .reg .b32 %r<43>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT: .reg .b32 %r<44>;
; CHECK-SM70-NEXT: .reg .f32 %f<13>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -1244,7 +1242,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-NEXT: setp.nan.f32 %p3, %f10, %f10;
; CHECK-SM70-NEXT: or.b32 %r33, %r29, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
; CHECK-SM70-NEXT: and.b32 %r35, %r15, -65536;
; CHECK-SM70-NEXT: mov.b32 %f11, %r35;
; CHECK-SM70-NEXT: max.f32 %f12, %f11, 0f00000000;
@@ -1255,8 +1252,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-NEXT: setp.nan.f32 %p4, %f12, %f12;
; CHECK-SM70-NEXT: or.b32 %r40, %r36, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT: mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT: prmt.b32 %r42, %r41, %r34, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r42;
; CHECK-SM70-NEXT: ret;
%1 = fmul <2 x bfloat> %a, %b
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
index 16219aa9da0950..0f24160af57af1 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
@@ -715,8 +715,8 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-SM70-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<7>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT: .reg .b32 %r<57>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT: .reg .b32 %r<58>;
; CHECK-SM70-NEXT: .reg .f32 %f<17>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -790,7 +790,6 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-SM70-NEXT: setp.nan.f32 %p5, %f14, %f14;
; CHECK-SM70-NEXT: or.b32 %r47, %r43, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r48, %r47, %r46, %p5;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r48; }
; CHECK-SM70-NEXT: and.b32 %r49, %r34, -65536;
; CHECK-SM70-NEXT: mov.b32 %f15, %r49;
; CHECK-SM70-NEXT: add.f32 %f16, %f15, %f9;
@@ -801,8 +800,7 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-SM70-NEXT: setp.nan.f32 %p6, %f16, %f16;
; CHECK-SM70-NEXT: or.b32 %r54, %r50, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r55, %r54, %r53, %p6;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r55; }
-; CHECK-SM70-NEXT: mov.b32 %r56, {%rs15, %rs13};
+; CHECK-SM70-NEXT: prmt.b32 %r56, %r55, %r48, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r56;
; CHECK-SM70-NEXT: ret;
%1 = call <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
@@ -841,8 +839,8 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
; CHECK-SM70-LABEL: fma_bf16x2_maxnum_no_nans(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<5>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT: .reg .b32 %r<43>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT: .reg .b32 %r<44>;
; CHECK-SM70-NEXT: .reg .f32 %f<13>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -896,7 +894,6 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
; CHECK-SM70-NEXT: setp.nan.f32 %p3, %f10, %f10;
; CHECK-SM70-NEXT: or.b32 %r33, %r29, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
; CHECK-SM70-NEXT: and.b32 %r35, %r15, -65536;
; CHECK-SM70-NEXT: mov.b32 %f11, %r35;
; CHECK-SM70-NEXT: max.f32 %f12, %f11, 0f00000000;
@@ -907,8 +904,7 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
; CHECK-SM70-NEXT: setp.nan.f32 %p4, %f12, %f12;
; CHECK-SM70-NEXT: or.b32 %r40, %r36, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT: mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT: prmt.b32 %r42, %r41, %r34, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r42;
; CHECK-SM70-NEXT: ret;
%1 = call <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
index af21bada7783be..63c2b96be838fa 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
@@ -785,8 +785,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<9>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<25>;
-; CHECK-SM70-NEXT: .reg .b32 %r<61>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<21>;
+; CHECK-SM70-NEXT: .reg .b32 %r<62>;
; CHECK-SM70-NEXT: .reg .f32 %f<19>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -869,7 +869,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-NEXT: setp.nan.f32 %p7, %f15, %f15;
; CHECK-SM70-NEXT: or.b32 %r49, %r45, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r50, %r49, %r48, %p7;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs20}, %r50; }
; CHECK-SM70-NEXT: cvt.u32.u16 %r51, %rs17;
; CHECK-SM70-NEXT: shl.b32 %r52, %r51, 16;
; CHECK-SM70-NEXT: mov.b32 %f16, %r52;
@@ -883,8 +882,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-SM70-NEXT: setp.nan.f32 %p8, %f18, %f18;
; CHECK-SM70-NEXT: or.b32 %r58, %r54, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r59, %r58, %r57, %p8;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs23}, %r59; }
-; CHECK-SM70-NEXT: mov.b32 %r60, {%rs23, %rs20};
+; CHECK-SM70-NEXT: prmt.b32 %r60, %r59, %r50, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r60;
; CHECK-SM70-NEXT: ret;
%1 = fmul fast <2 x bfloat> %a, %b
@@ -924,8 +922,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-LABEL: fma_bf16x2_expanded_maxnum_no_nans(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<5>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT: .reg .b32 %r<43>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT: .reg .b32 %r<44>;
; CHECK-SM70-NEXT: .reg .f32 %f<13>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -979,7 +977,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-NEXT: setp.nan.f32 %p3, %f10, %f10;
; CHECK-SM70-NEXT: or.b32 %r33, %r29, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
; CHECK-SM70-NEXT: and.b32 %r35, %r15, -65536;
; CHECK-SM70-NEXT: mov.b32 %f11, %r35;
; CHECK-SM70-NEXT: max.f32 %f12, %f11, 0f00000000;
@@ -990,8 +987,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
; CHECK-SM70-NEXT: setp.nan.f32 %p4, %f12, %f12;
; CHECK-SM70-NEXT: or.b32 %r40, %r36, 4194304;
; CHECK-SM70-NEXT: selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT: mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT: prmt.b32 %r42, %r41, %r34, 0x7632U;
; CHECK-SM70-NEXT: st.param.b32 [func_retval0], %r42;
; CHECK-SM70-NEXT: ret;
%1 = fmul fast <2 x bfloat> %a, %b
@@ -1710,8 +1706,8 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-SM70-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
; CHECK-SM70: {
; CHECK-SM70-NEXT: .reg .pred %p<7>;
-; CHECK-SM70-NEXT: .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT: .reg .b32 %r<57>;
+; CHECK-SM70-NEXT: .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT: .reg .b32 %r<58>;
; CHECK-SM70-NEXT: .reg .f32 %f<17>;
; CHECK-SM70-EMPTY:
; CHECK-SM70-NEXT: // %bb.0:
@@ -1785,7 +1781,6 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-SM70-NEXT: setp.nan.f32 %p5, %f14, %f14;
; ...
[truncated]
Nice. LGTM with a few minor nits.
if (Op->getOpcode() == ISD::BITCAST)
  *Op = Op->getOperand(0);

if (Op->getValueType() != MVT::i16 || Op->getOpcode() != ISD::TRUNCATE ||
It might be good to check that some of these instructions only have a single use, otherwise this transform is not a clear win and may increase register pressure.
I'll write some tests to check what happens. Artem in another comment thinks that it might be okay if the instructions aren't one-use.
I've tried some simple tests with multiple uses of the truncate and/or the original value.
When you reuse the truncate you appear to increase register pressure (though the SASS remains the same): https://godbolt.org/z/MT8PqshsW
When you reuse the original value the register pressure looks better, indicating the PRMT is worthwhile. Though the SASS is the same: https://godbolt.org/z/MqdbT6W59
When you reuse both, the register pressure is still worse, though the SASS remains the same: https://godbolt.org/z/xj46avWqY
So, even though the SASS remains the same in these simple examples, it indicates we should probably bail out if the truncate has multiple uses. Multiple uses of the original value appear to be alright.
Oh yes I should add, those links are using ptxas v11.5.0. It's the same result with the most recent 12.6.1; the SASS is always the same between the PRMT and non-PRMT versions.
I've pushed a change that forbids multiple uses of the truncate, and have added tests for it.
the register pressure is still worse, though the SASS remains the same
What register pressure are you referring to here? Is it the number of declared registers:
.reg .b32 %r<number of registers declared>;
Do we know of a case where the SASS differs due to this change?
Yeah, I was referring to the number of declared registers. I thought that's what you were referring to, as it's the only thing we have "control" of.
I've only tried the examples I linked to above, but no, I've not seen a difference in the actual SASS. I could imagine that more PTX registers could manifest as worse SASS in more complex examples? It's not much to go on, though.
On the tests we're consistently declaring four fewer .b16 registers, and one more .b32. It looks like a net win to me, even if it does not matter for SASS. It means we're doing less unnecessary stuff with .b16 intermediate values.
Overall, this is great! Thanks for following up on @Artem-B's comment :)
; CHECK-SM70-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs23}, %r59; }
; CHECK-SM70-NEXT: mov.b32 %r60, {%rs23, %rs20};
Are we now able to delete the code that generates this?
I don't think so, because it's still used for (trunc (srl s, 16)) or (extractelt $vec, 1). Perhaps if we generally matched both of those to PRMTs we could remove the code, but I suspect we'll always need the option to fall back to these patterns.
(trunc (srl s, 16)) or (extractelt $vec, 1) should be partially converted to prmt, too. The part that moves bits in multiples of 8 to the LSB of i32 maps to permute, and trunc would just be a regular truncating move. Does not have to be done in this patch, but if the change is trivial, it may fit here, too. Up to you.
// Optionally, fold in a shift-right of the original operand and permute
// the two higher bytes from the shifted operand
if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
  if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
    *OpBytes += 0x22;
    *Op = Op->getOperand(0);
  }
}
This seems like a separate optimization to me: prmt a (shr b [8|16|24]) x -> prmt a b y.
Could we implement a separate DAG combine to do this? That way, LowerBUILD_VECTOR (and eventually LowerINSERT_VECTOR_ELT and LowerEXTRACT_VECTOR_ELT once they begin to use prmt) can use it.
Piggybacking off of this, if we do perform prmt a (shr b [8|16|24]) x -> prmt a b y separately, then should this be a custom lowering instead of a DAG combine?
This seems like a separate optimization to me: prmt a (shr b [8|16|24]) x -> prmt a b y.
One way we could generalize that is via a two-step process:
1. Convert shifts and truncates by 8/16/24, and vector extraction of i8/i16, into prmt. This gives us convenient access to an i32 and a mask of valid bits in it. E.g. shr b, 16 -> prmt b, 0, 0x4410.
2. Then combine individual i32s with partial values into one. E.g. or(prmt(a,0,0x4410), prmt(b,0,0x1044), 16)) -> prmt(a,b,0x5410). It would be applicable to prmt itself, or, build_vector, insert_vector_elt, etc.
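The folding principle in the second step can be sanity-checked numerically: or-ing two "partial" permutes, each with its unwanted bytes zeroed, is equivalent to a single permute with a merged selector. Note this is an illustrative sketch; the selector encodings are my own derivation under prmt's little-endian nibble convention and are not meant to reproduce the exact shorthand masks in the comment above.

```python
def prmt(a: int, b: int, selector: int) -> int:
    # Minimal model of PTX prmt.b32 default mode: selector nibble i
    # picks result byte i from the pool {a: bytes 0-3, b: bytes 4-7}.
    pool = a.to_bytes(4, "little") + b.to_bytes(4, "little")
    out = bytes(pool[(selector >> (4 * i)) & 0x7] for i in range(4))
    return int.from_bytes(out, "little")

a, b = 0xAABBCCDD, 0x11223344

# Step 1: each "(trunc (srl x, 16))" becomes a permute moving x's high
# two bytes to the LSB, upper result bytes zeroed (index 4 selects
# byte 0 of a zero second operand).
lo_a = prmt(a, 0, 0x4432)
lo_b = prmt(b, 0, 0x4432)
assert lo_a == a >> 16 and lo_b == b >> 16

# Step 2: or-ing the partials (one shifted into the high half)
# collapses into one permute with a merged selector: exactly the
# 0x7632 this patch emits for the BUILD_VECTOR.
assert (lo_a | (lo_b << 16)) == prmt(a, b, 0x7632)
```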
My original approach was to custom-lower v2x16 BUILD_VECTORs to PRMT, but I found it was getting too complicated to identify the cases where PRMT was optimal. As discussed in another comment thread, unless the original values are i32, this optimization broadly isn't worth it. During lowering of BUILD_VECTOR, we often don't as easily see the 16-bit-from-32-bit values that make it worthwhile. My concern with being too eager about BUILD_VECTOR is that there will necessarily be cases where we need to undo a suboptimal PRMT, which feels wrong to me.
I imagine that Artem's idea could work, however. I suspect we'll still have to undo an eagerly matched PRMT, though. Perhaps we'd have to be more careful with one-use nodes?
How do you see this idea working in the case of something like generic_2xi16 here?
In this case, it's better to extract to actual 16-bit registers from the vector using mov.b32 {%rs1, %rs2}, %r1, because we're doing scalarized 16-bit operations on them. We can't detect that at the point at which we do LowerEXTRACT_VECTOR_ELEMENT. Well, we can, but it's not common to have to look at uses when lowering nodes, and it's a bit of a red flag to me. The problem with eagerly creating PRMTs when lowering the vector extracts is that it gives us these awkward 32-bit registers, which we have to undo.
We'd probably have to do what NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT is doing: try to find two truncates from two PRMTs from the same original vector and replace them all with a single NVPTX::I32ToV2I16. This starts to sound a little strange, especially since we haven't proven that using PRMTs in this way brings any tangible benefits.
What do you think?
The '32-bit-source + mask' intermediate-value approach would be able to preserve that nice example. Nothing stops us from lowering a pair of SM(value, 0xffff), SM(value, 0xffff0000) as mov.s32 {rs1, rs2}, value when we need two 16-bit values, but using prmt when the destination is 32-bit, e.g. if your example above would only swap those 16-bit parts, without addition.
I suppose so, yes.
I was also wondering whether we'd need a way of encapsulating the idea of an "unused" byte for maximum effect, as in the example of truncation: i16 = trunc (i32 a) -> prmt a, b, 0xXX10. There's no concrete byte of a or b to select for the truncated-out bytes, nor any concrete value of b: both can be anything we want. If we choose any concrete values or indices, it might hinder or confuse optimal combines/lowering. Using something like undef as b would work, and maybe selecting any byte from an undef operand would give us enough info to know we can replace them with anything.
In any case, I won't have the time to dedicate to trying out any of these ideas, I'm afraid. If you want to take it on, go ahead! We could either merge this PR as-is, or just abandon it. I'm easy.
You can always use zero as one of the PRMT inputs, or sign-extend upper bits.
This particular change is in the nice-to-have category but it's not particularly important. It's more of a "how we could produce slightly better code with a cleaner/nicer implementation", but given that things already work reasonably well the upside is very limited. It may be a suitable starter project for someone getting up to speed with LLVM/NVPTX in the future.
Apologies, I didn't mean to click the request review button.
The change itself LGTM :)
LGTM.