[DAGCombiner] Fold pattern for srl-shl-zext #138290

Merged
merged 10 commits into from
May 14, 2025
16 changes: 16 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10972,6 +10972,22 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
      return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
  }

  // fold (srl (logic_op x, (shl (zext y), c1)), c1)
  //   -> (logic_op (srl x, c1), (zext y))
  // c1 <= leadingzeros(zext(y))
  SDValue X, ZExtY;
  if (N1C && sd_match(N0, m_OneUse(m_BitwiseLogic(
                              m_Value(X),
                              m_OneUse(m_Shl(m_AllOf(m_Value(ZExtY),
                                                     m_Opc(ISD::ZERO_EXTEND)),
                                             m_Specific(N1))))))) {
    unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
                               ZExtY.getOperand(0).getScalarValueSizeInBits();
Collaborator

N1C && ISD::isBitwiseLogicOp(N0.getOpcode()))

(ideally we'd use sd_match but we're missing m_BitwiseLogic)

Contributor Author

Done. The ISD::isBitwiseLogicOp helper is now used.

Collaborator

Since #138301 we now have m_BitwiseLogic, if you want to use SDPatternMatch to simplify the commutative matching, but this is optional.

Contributor Author

OK, I rewrote it with sd_match. It's definitely more concise.

I didn't find a built-in way to match and bind a node with a specific opcode, so I used the following construction: m_AllOf(m_Value(ZExtY), m_Opc(ISD::ZERO_EXTEND)). If you know a more elegant solution, please point it out.

    if (N1C->getZExtValue() <= NumLeadingZeros)
      return DAG.getNode(N0.getOpcode(), SDLoc(N0), VT,
                         DAG.getNode(ISD::SRL, SDLoc(N0), VT, X, N1), ZExtY);
  }

Collaborator

(style) remove braces

Contributor Author

removed

Contributor

I've always told people to put the braces in when the body is more than one physical line, even if it's a single logical line of code (or a single statement). But I'm not sure if that's written down anywhere.

Collaborator

the coding style just refers to "simple statements" :)

Contributor

I'm in favor of more braces always

Contributor Author

I guess it's not a critical point and can be left as is (no braces). I personally prefer braces, but in this particular case omitting them is in line with the coding standards, from my point of view.

braces should be used when a single-statement body is complex enough that it becomes difficult to see where the block containing the following statement began.

This statement doesn't seem complex to me.

  // fold operands of srl based on knowledge that the low bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
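
The guard `c1 <= leadingzeros(zext(y))` in the fold above can be sanity-checked outside LLVM. The following is a minimal sketch, not part of the patch: it models the i32-to-i64 case with plain integers and compares both forms on pseudo-random inputs. The names `LogicOp`, `applyLogic`, `beforeFold`, `afterFold` and `foldAgrees` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <random>

// Hypothetical stand-in for the three bitwise logic opcodes (and/or/xor).
enum class LogicOp { And, Or, Xor };

inline uint64_t applyLogic(LogicOp Op, uint64_t A, uint64_t B) {
  switch (Op) {
  case LogicOp::And:
    return A & B;
  case LogicOp::Or:
    return A | B;
  case LogicOp::Xor:
    return A ^ B;
  }
  return 0; // not reached for valid enumerators
}

// Original form: srl (logic_op x, (shl (zext y), c1)), c1
inline uint64_t beforeFold(LogicOp Op, uint64_t X, uint32_t Y, unsigned C1) {
  uint64_t ZExtY = Y; // zext i32 -> i64
  return applyLogic(Op, X, ZExtY << C1) >> C1;
}

// Folded form: logic_op (srl x, c1), (zext y)
inline uint64_t afterFold(LogicOp Op, uint64_t X, uint32_t Y, unsigned C1) {
  uint64_t ZExtY = Y;
  return applyLogic(Op, X >> C1, ZExtY);
}

// Returns true if both forms agree for every sampled (x, y) and every opcode.
// C1 must be below 64, since shifting a 64-bit value by 64 or more is undefined.
inline bool foldAgrees(unsigned C1, unsigned Trials = 100000) {
  std::mt19937_64 Rng(12345); // fixed seed keeps the check reproducible
  for (unsigned I = 0; I < Trials; ++I) {
    uint64_t X = Rng();
    uint32_t Y = static_cast<uint32_t>(Rng());
    for (LogicOp Op : {LogicOp::And, LogicOp::Or, LogicOp::Xor})
      if (beforeFold(Op, X, Y, C1) != afterFold(Op, X, Y, C1))
        return false;
  }
  return true;
}
```

In this model the zext adds 64 - 32 = 32 leading zeros, so `foldAgrees(5)` and `foldAgrees(32)` hold, while `foldAgrees(33)` fails because the shl already discards the top bit of y.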
152 changes: 152 additions & 0 deletions llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -0,0 +1,152 @@
; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s

define i64 @test_or(i64 %x, i32 %y) {
;
; Fold: srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
; c1 <= leadingzeros(zext(y))
;
; CHECK-LABEL: test_or
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_or_param_0];
; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test_or_param_1];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[X]], 5;
; CHECK: or.b64 %[[LOP:rd[0-9]+]], %[[SHR]], %[[Y]];
; CHECK: st.param.b64 [func_retval0], %[[LOP]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = or i64 %x, %shl
Contributor

The code handles and and xor, but those aren't tested here. It should also test vector cases, and include negative tests for multiple uses and for not enough known bits.

Contributor Author

I added tests for:

  • xor and and
  • Vector or
  • Negative cases: multiple uses of logic_op and of shl

What did you mean by "not enough known bits"? The case "c1 > leadingzeros(zext(y))" was already covered by a test.

%srl = lshr i64 %or, 5
ret i64 %srl
}

define i64 @test_xor(i64 %x, i32 %y) {
;
; Fold: srl (xor (x, shl(zext(y),c1)),c1) -> xor(srl(x,c1), zext(y))
; c1 <= leadingzeros(zext(y))
;
; CHECK-LABEL: test_xor
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_xor_param_0];
; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test_xor_param_1];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[X]], 5;
; CHECK: xor.b64 %[[LOP:rd[0-9]+]], %[[SHR]], %[[Y]];
; CHECK: st.param.b64 [func_retval0], %[[LOP]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = xor i64 %x, %shl
%srl = lshr i64 %or, 5
ret i64 %srl
}

define i64 @test_and(i64 %x, i32 %y) {
;
; Fold: srl (and (x, shl(zext(y),c1)),c1) -> and(srl(x,c1), zext(y))
; c1 <= leadingzeros(zext(y))
;
; CHECK-LABEL: test_and
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_and_param_0];
; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test_and_param_1];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[X]], 5;
; CHECK: and.b64 %[[LOP:rd[0-9]+]], %[[SHR]], %[[Y]];
; CHECK: st.param.b64 [func_retval0], %[[LOP]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = and i64 %x, %shl
%srl = lshr i64 %or, 5
ret i64 %srl
}

define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
;
; Fold: srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
; c1 <= leadingzeros(zext(y))
; x, y - vectors
;
; CHECK-LABEL: test_vec
; CHECK: ld.param.u32 %[[X:r[0-9]+]], [test_vec_param_0];
; CHECK: ld.param.u32 %[[P1:r[0-9]+]], [test_vec_param_1];
; CHECK: and.b32 %[[Y:r[0-9]+]], %[[P1]], 16711935;
; CHECK: mov.b32 {%[[X1:rs[0-9]+]], %[[X2:rs[0-9]+]]}, %[[X]];
; CHECK: shr.u16 %[[SHR2:rs[0-9]+]], %[[X2]], 5;
; CHECK: shr.u16 %[[SHR1:rs[0-9]+]], %[[X1]], 5;
; CHECK: mov.b32 %[[SHR:r[0-9]+]], {%[[SHR1]], %[[SHR2]]};
; CHECK: or.b32 %[[LOP:r[0-9]+]], %[[SHR]], %[[Y]];
; CHECK: st.param.b32 [func_retval0], %[[LOP]];
;
%ext = zext <2 x i8> %y to <2 x i16>
%shl = shl <2 x i16> %ext, splat(i16 5)
%or = or <2 x i16> %x, %shl
%srl = lshr <2 x i16> %or, splat(i16 5)
ret <2 x i16> %srl
}
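
For the vector test above, each lane is an i8 zero-extended to i16, so the zext contributes 8 leading zeros per lane and a shift of 5 satisfies the guard. The sketch below is an illustration rather than part of the test: it checks one lane exhaustively and notes that the constant 16711935 in the CHECK line equals 0x00FF00FF, which presumably keeps the low byte of each 16-bit lane. The names `LaneMask`, `laneBefore`, `laneAfter` and `lanesAgreeExhaustively` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// 16711935 from the CHECK line, written in hex.
constexpr uint32_t LaneMask = 0x00FF00FFu;
static_assert(LaneMask == 16711935u, "decimal and hex forms match");

// One i16 lane of: srl (or x, (shl (zext y), c1)), c1
inline uint16_t laneBefore(uint16_t X, uint8_t Y, unsigned C1) {
  // The shl happens in i16, so bits shifted past bit 15 are dropped.
  uint16_t Shl = static_cast<uint16_t>(static_cast<uint32_t>(Y) << C1);
  uint16_t Or = static_cast<uint16_t>(X | Shl);
  return static_cast<uint16_t>(Or >> C1);
}

// One i16 lane of: or (srl x, c1), (zext y)
inline uint16_t laneAfter(uint16_t X, uint8_t Y, unsigned C1) {
  return static_cast<uint16_t>((X >> C1) | Y);
}

// Compares both forms over every (x, y) pair for a single lane.
inline bool lanesAgreeExhaustively(unsigned C1) {
  for (uint32_t X = 0; X <= 0xFFFFu; ++X)
    for (uint32_t Y = 0; Y <= 0xFFu; ++Y)
      if (laneBefore(static_cast<uint16_t>(X), static_cast<uint8_t>(Y), C1) !=
          laneAfter(static_cast<uint16_t>(X), static_cast<uint8_t>(Y), C1))
        return false;
  return true;
}
```

Under this model the two forms agree for shift amounts up to 8 and diverge at 9, where the top bit of y no longer survives the i16 shl.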

define i64 @test_negative_c(i64 %x, i32 %y) {
;
; Do not fold: srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
; Reason: c1 > leadingzeros(zext(y)).
;
; CHECK-LABEL: test_negative_c
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_negative_c_param_0];
; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test_negative_c_param_1];
; CHECK: shl.b64 %[[SHL:rd[0-9]+]], %[[Y]], 33;
; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[OR]], 33;
; CHECK: st.param.b64 [func_retval0], %[[SHR]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 33
%or = or i64 %x, %shl
%srl = lshr i64 %or, 33
ret i64 %srl
}
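
A concrete input makes the reason for this negative test visible. The sketch below uses plain 64-bit arithmetic as a stand-in for the DAG nodes. The helper names `unfoldedOr` and `foldedOr` are hypothetical, and the chosen values (x = 0, y = 0x80000000) are just one counterexample among many.

```cpp
#include <cassert>
#include <cstdint>

// srl (or x, (shl (zext y), c1)), c1, evaluated on i64 with y zero-extended from i32.
constexpr uint64_t unfoldedOr(uint64_t X, uint32_t Y, unsigned C1) {
  return (X | (static_cast<uint64_t>(Y) << C1)) >> C1;
}

// or (srl x, c1), (zext y): what the fold would produce.
constexpr uint64_t foldedOr(uint64_t X, uint32_t Y, unsigned C1) {
  return (X >> C1) | static_cast<uint64_t>(Y);
}

// With c1 = 32 the shifted value still fits in 64 bits, so both forms agree.
static_assert(unfoldedOr(0, 0x80000000u, 32) == foldedOr(0, 0x80000000u, 32),
              "c1 == leadingzeros: fold is safe");

// With c1 = 33 the top bit of y is shifted out before the srl brings it back.
static_assert(unfoldedOr(0, 0x80000000u, 33) == 0,
              "original form loses bit 31 of y");
static_assert(foldedOr(0, 0x80000000u, 33) == 0x80000000ull,
              "folded form would keep bit 31 of y");
```

Since the two forms give different results for c1 = 33, the combine has to leave the shl/or/srl sequence intact, which is what the CHECK lines in test_negative_c expect.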

declare void @use(i64)

define i64 @test_negative_use_lop(i64 %x, i32 %y) {
;
; Do not fold: srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
; Reason: multiple usage of "or"
;
; CHECK-LABEL: test_negative_use_lop
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_negative_use_lop_param_0];
; CHECK: ld.param.u32 %[[Y:r[0-9]+]], [test_negative_use_lop_param_1];
; CHECK: mul.wide.u32 %[[SHL:rd[0-9]+]], %[[Y]], 32;
; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[OR]], 5;
; CHECK: { // callseq
; CHECK: st.param.b64 [param0], %[[OR]];
; CHECK: } // callseq
; CHECK: st.param.b64 [func_retval0], %[[SHR]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = or i64 %x, %shl
%srl = lshr i64 %or, 5
call void @use(i64 %or)
ret i64 %srl
}
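
The CHECK lines of this test expect `mul.wide.u32 ..., 32` where the IR has a zext followed by a shl by 5. Assuming the backend selects the widening multiply for that pair, the two computations are the same, since a 32-bit by 32-bit unsigned widening multiply by 2^5 equals a left shift by 5 of the zero-extended value. A minimal sketch with hypothetical helper names:

```cpp
#include <cassert>
#include <cstdint>

// shl (zext i32 y to i64), 5
inline uint64_t shlOfZExt(uint32_t Y) {
  return static_cast<uint64_t>(Y) << 5;
}

// Model of a u32 x u32 -> u64 widening multiply by 32.
inline uint64_t mulWideBy32(uint32_t Y) {
  return static_cast<uint64_t>(Y) * static_cast<uint64_t>(32u);
}

// Spot-checks a few values, including the largest i32 input.
inline bool shlMatchesMulWide() {
  const uint32_t Samples[] = {0u, 1u, 0x12345678u, 0x80000000u, 0xFFFFFFFFu};
  for (uint32_t Y : Samples)
    if (shlOfZExt(Y) != mulWideBy32(Y))
      return false;
  return true;
}
```

The largest result, 0xFFFFFFFF * 32, needs only 37 bits, so neither form overflows the 64-bit destination.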

define i64 @test_negative_use_shl(i64 %x, i32 %y) {
;
; Do not fold: srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
; Reason: multiple usage of "shl"
;
; CHECK-LABEL: test_negative_use_shl
; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test_negative_use_shl_param_0];
; CHECK: ld.param.u32 %[[Y:r[0-9]+]], [test_negative_use_shl_param_1];
; CHECK: mul.wide.u32 %[[SHL:rd[0-9]+]], %[[Y]], 32;
; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[OR]], 5;
; CHECK: { // callseq
; CHECK: st.param.b64 [param0], %[[SHL]];
; CHECK: } // callseq
; CHECK: st.param.b64 [func_retval0], %[[SHR]];
;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = or i64 %x, %shl
%srl = lshr i64 %or, 5
call void @use(i64 %shl)
ret i64 %srl
}