-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVPTX] Support BFloat Store Parameter #137074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Steffi Stumpos (stumpOS) ChangesBefore this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode. Full diff: https://github.com/llvm/llvm-project/pull/137074.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec1f969494cd1..e74c8828aaf1b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -583,7 +583,7 @@ getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
// |------------------------------------------------------|-------------------------------|
// | cuda::atomic_load | fence.sc.<scope>; |
// | (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>; |
- // |------------------------------------------------------|-------------------------------|
+ // |------------------------------------------------------|-------------------------------|
// | cuda::atomic_store | fence.sc.<scope>; |
// | (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>; |
// |------------------------------------------------------|-------------------------------|
@@ -1852,7 +1852,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case 1: {
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
SDValue Imm = Ops[0];
- if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+ if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 &&
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
// Convert immediate to target constant
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
@@ -2808,8 +2808,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) {
SDLoc DL(N);
SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs));
Ops.push_back(N->getOperand(0)); // Chain operand
-
- unsigned Opcode = IsCacheHint
+
+ unsigned Opcode = IsCacheHint
? NVPTX::CP_ASYNC_BULK_PREFETCH_CH
: NVPTX::CP_ASYNC_BULK_PREFETCH;
ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
diff --git a/llvm/test/CodeGen/NVPTX/st-param-imm.ll b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
index ab1447607ab65..d5463b04b3b72 100644
--- a/llvm/test/CodeGen/NVPTX/st-param-imm.ll
+++ b/llvm/test/CodeGen/NVPTX/st-param-imm.ll
@@ -2000,3 +2000,27 @@ declare void @call_v4_i8(%struct.char4 alignstack(4))
declare void @call_v4_i16(%struct.short4 alignstack(8))
declare void @call_v4_i32(%struct.int4 alignstack(16))
declare void @call_v4_f32(%struct.float4 alignstack(16))
+
+define void @st_param_bfloat() {
+; CHECK-LABEL: st_param_bfloat(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:// %bb.0:
+; CHECK-NEXT: mov.b16 %rs1, 0x4100;
+; CHECK-NEXT: { // callseq 83, 0
+; CHECK-NEXT: .param .align 2 .b8 param0[2];
+; CHECK-NEXT: st.param.b16 [param0], %rs1;
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: call_bfloat,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 83
+; CHECK-NEXT: ret;
+ %five = bitcast i16 16640 to bfloat
+ call void @call_bfloat(bfloat %five)
+ ret void
+}
+
+declare void @call_bfloat(bfloat)
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
define void @st_param_v2bfloat(<2 x bfloat> %val) { | ||
; CHECK-LABEL: st_param_v2bfloat( | ||
; CHECK: .param .align 4 .b8 st_param_v2bfloat_param_0[4] | ||
; CHECK-NEXT: ) | ||
; CHECK-NEXT: { | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b32 %r1, [st_param_v2bfloat_param_0]; | ||
; CHECK-NEXT: { // callseq 84, 0 | ||
; CHECK-NEXT: .param .align 4 .b8 param0[4]; | ||
; CHECK-NEXT: st.param.b32 [param0], %r1; | ||
; CHECK-NEXT: call.uni | ||
; CHECK-NEXT: call_v2bfloat, | ||
; CHECK-NEXT: ( | ||
; CHECK-NEXT: param0 | ||
; CHECK-NEXT: ); | ||
; CHECK-NEXT: } // callseq 84 | ||
; CHECK-NEXT: ret; | ||
call void @call_v2bfloat(<2 x bfloat> %val) | ||
ret void | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't look like the output of this test has changed (source)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably because ConstantFPSDNode do not have vector types; I will remove this test and the vector check (see Alex's comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! A little more cleanup can be performed here.
@@ -1852,7 +1852,8 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { | |||
case 1: { | |||
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy; | |||
SDValue Imm = Ops[0]; | |||
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && | |||
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && MemTy != MVT::bf16 && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can probably just remove the vector types here. A ConstantFPSDNode should never have one of these types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I will remove both vector type checks, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate why is it correct to remove them ? v2f16
and v2bf16
here are guarding handling of imms within this if-statement, so removing it will treat these constants as if they're ConstantInt
, which should assert, right ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The conditional also requires that the node type is either ConstantInt or ConstantFloat but when the vector type is used the node type is neither; I was unable to create a test that generates a constant node with vector type. In the test I added and then removed this code path was not hit because the node type was MemIntrinsicSDNode. This is why the test I added for vector types passed before my change (see Justin's comment above)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argh, yes. Somehow missed that.
1a9f4c5
to
b85acb2
Compare
Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.
Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.
Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.
Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.
Before this patch, the instruction selector assumed that if the Memory Type is not {f16, v2f16, f32, f64} then the node type must be a ConstantSDNode when in fact if the memory type is bf16 then the node type is ConstantFPSDNode.