Skip to content

[NVPTX] Support BFloat Store Parameter #137074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 24, 2025
Merged

Conversation

stumpOS
Copy link
Contributor

@stumpOS stumpOS commented Apr 23, 2025

Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.

@llvmbot
Copy link
Member

llvmbot commented Apr 23, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Steffi Stumpos (stumpOS)

Changes

Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.


Full diff: https://github.com/llvm/llvm-project/pull/137074.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+24)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec1f969494cd1..e74c8828aaf1b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -583,7 +583,7 @@ getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
   // |------------------------------------------------------|-------------------------------|
   // | cuda::atomic_load                                    | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>;           |
-  // |------------------------------------------------------|-------------------------------|  
+  // |------------------------------------------------------|-------------------------------|
   // | cuda::atomic_store                                   | fence.sc.<scope>;             |
   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>;           |
   // |------------------------------------------------------|-------------------------------|
@@ -1852,7 +1852,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
@@ -2808,8 +2808,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) {
   SDLoc DL(N);
   SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs));
   Ops.push_back(N->getOperand(0)); // Chain operand
-  
-  unsigned Opcode = IsCacheHint 
+
+  unsigned Opcode = IsCacheHint
   ?  NVPTX::CP_ASYNC_BULK_PREFETCH_CH
   :  NVPTX::CP_ASYNC_BULK_PREFETCH;
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
diff --git a/llvm/test/CodeGen/NVPTX/st-param-imm.ll b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
index ab1447607ab65..d5463b04b3b72 100644
--- a/llvm/test/CodeGen/NVPTX/st-param-imm.ll
+++ b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
@@ -2000,3 +2000,27 @@ declare void @call_v4_i8(%struct.char4 alignstack(4))
 declare void @call_v4_i16(%struct.short4 alignstack(8))
 declare void @call_v4_i32(%struct.int4 alignstack(16))
 declare void @call_v4_f32(%struct.float4 alignstack(16))
+
+define void @st_param_bfloat() {
+; CHECK-LABEL: st_param_bfloat(
+; CHECK: {
+; CHECK-NEXT:	.reg .b16 	%rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:// %bb.0:
+; CHECK-NEXT:	mov.b16 	%rs1, 0x4100;
+; CHECK-NEXT:	{ // callseq 83, 0
+; CHECK-NEXT:	.param .align 2 .b8 param0[2];
+; CHECK-NEXT:	st.param.b16 	[param0], %rs1;
+; CHECK-NEXT:	call.uni
+; CHECK-NEXT:	call_bfloat,
+; CHECK-NEXT:	(
+; CHECK-NEXT:	param0
+; CHECK-NEXT:	);
+; CHECK-NEXT:	} // callseq 83
+; CHECK-NEXT:	ret;
+  %five = bitcast i16 16640 to bfloat
+  call void @call_bfloat(bfloat %five)
+  ret void
+}
+
+declare void @call_bfloat(bfloat)

Copy link

github-actions bot commented Apr 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@npanchen npanchen requested review from Artem-B and AlexMaclean April 23, 2025 22:40
Comment on lines 2026 to 2047
define void @st_param_v2bfloat(<2 x bfloat> %val) {
; CHECK-LABEL: st_param_v2bfloat(
; CHECK: .param .align 4 .b8 st_param_v2bfloat_param_0[4]
; CHECK-NEXT: )
; CHECK-NEXT: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [st_param_v2bfloat_param_0];
; CHECK-NEXT: { // callseq 84, 0
; CHECK-NEXT: .param .align 4 .b8 param0[4];
; CHECK-NEXT: st.param.b32 [param0], %r1;
; CHECK-NEXT: call.uni
; CHECK-NEXT: call_v2bfloat,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: } // callseq 84
; CHECK-NEXT: ret;
call void @call_v2bfloat(<2 x bfloat> %val)
ret void
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't look like the output of this test has changed (source)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably because ConstantFPSDNode do not have vector types; I will remove this test and the vector check (see Alex's comment)

Copy link
Member

@AlexMaclean AlexMaclean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! A little more cleanup can be performed here.

@@ -1852,7 +1852,8 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case 1: {
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
SDValue Imm = Ops[0];
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can probably just remove the vector types here. A ConstantFPSDNode should never have one of these types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will remove both vector type checks, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate why is it correct to remove them ? v2f16 and v2bf16 here are guarding handling of imms within this if-statement, so removing it will treat these constants as if they're ConstantInt, which should assert, right ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditional also requires that the node type is either ConstantInt or ConstantFloat but when the vector type is used the node type is neither; I was unable to create a test that generates a constant node with vector type. In the test I added and then removed this code path was not hit because the node type was MemIntrinsicSDNode. This is why the test I added for vector types passed before my change (see Justin's comment above)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argh, yes. Somehow missed that.

@npanchen npanchen merged commit c007c46 into llvm:main Apr 24, 2025
11 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
Before this patch, the instruction selector assumed that if the Memory
Type is not {f16, v2f16, f32, f64} then the node type must be a
ConstantSDNode when in fact if the memory type is bf16 then the node
type is ConstantFPSDNode.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
Before this patch, the instruction selector assumed that if the Memory
Type is not {f16, v2f16, f32, f64} then the node type must be a
ConstantSDNode when in fact if the memory type is bf16 then the node
type is ConstantFPSDNode.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
Before this patch, the instruction selector assumed that if the Memory
Type is not {f16, v2f16, f32, f64} then the node type must be a
ConstantSDNode when in fact if the memory type is bf16 then the node
type is ConstantFPSDNode.
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
Before this patch, the instruction selector assumed that if the Memory
Type is not {f16, v2f16, f32, f64} then the node type must be a
ConstantSDNode when in fact if the memory type is bf16 then the node
type is ConstantFPSDNode.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants