
[RISCV] Lower memory ops and VP splat for zvfhmin and zvfbfmin #109387


Merged: 3 commits merged into llvm:main on Sep 25, 2024

Conversation

lukel97 (Contributor) commented Sep 20, 2024

We can lower f16/bf16 memory ops without promotion through the existing custom lowering.

Some zero-strided VP loads get combined into a VP splat, so we also need to handle that lowering for f16/bf16 with zvfhmin/zvfbfmin. This patch copies the lowering from ISD::SPLAT_VECTOR over to lowerScalarSplat, which is used by the VP splat lowering.
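
For illustration, the kind of input that triggers this is a VP strided load with stride 0, roughly like the sketch below (the function name is made up and the intrinsic name mangling is approximate, not copied from the patch):

declare <vscale x 1 x bfloat> @llvm.experimental.vp.strided.load.nxv1bf16.p0.i32(ptr, i32, <vscale x 1 x i1>, i32)

define <vscale x 1 x bfloat> @zero_strided_vpload_sketch(ptr %ptr, i32 zeroext %evl) {
  ; Stride 0 means every active lane reads the same address, so the
  ; DAG combine can turn this into a scalar load plus a VP splat.
  %v = call <vscale x 1 x bfloat> @llvm.experimental.vp.strided.load.nxv1bf16.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 1), i32 %evl)
  ret <vscale x 1 x bfloat> %v
}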

llvmbot (Member) commented Sep 20, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes


Patch is 104.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109387.diff

12 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28-5)
  • (modified) llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll (+70-2)
  • (modified) llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll (+70-2)
  • (modified) llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll (+212-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll (+190-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll (+106-38)
  • (modified) llvm/test/CodeGen/RISCV/rvv/strided-vpstore.ll (+76-12)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vp-splat.ll (+220-32)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpgather-sdnode.ll (+211-18)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpload.ll (+72-10)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpscatter-sdnode.ll (+201-18)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpstore.ll (+62-10)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c4458b14f36ece..f9dc440b3c4176 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1082,10 +1082,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          VT, Custom);
       MVT EltVT = VT.getVectorElementType();
       if (isTypeLegal(EltVT))
-        setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+        setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT}, VT,
+                           Custom);
       else
-        setOperationAction(ISD::SPLAT_VECTOR, EltVT, Custom);
-      setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+        setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
+                           EltVT, Custom);
+      setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
+                          ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD,
+                          ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
+                          ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
+                          ISD::VP_SCATTER},
+                         VT, Custom);
 
       setOperationAction(ISD::FNEG, VT, Expand);
       setOperationAction(ISD::FABS, VT, Expand);
@@ -4449,11 +4456,27 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
   bool HasPassthru = Passthru && !Passthru.isUndef();
   if (!HasPassthru && !Passthru)
     Passthru = DAG.getUNDEF(VT);
-  if (VT.isFloatingPoint())
-    return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
 
+  MVT EltVT = VT.getVectorElementType();
   MVT XLenVT = Subtarget.getXLenVT();
 
+  if (VT.isFloatingPoint()) {
+    if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
+        EltVT == MVT::bf16) {
+      if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
+          (EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
+        Scalar = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Scalar);
+      else
+        Scalar = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Scalar);
+      MVT IVT = VT.changeVectorElementType(MVT::i16);
+      Passthru = DAG.getNode(ISD::BITCAST, DL, IVT, Passthru);
+      SDValue Splat =
+          lowerScalarSplat(Passthru, Scalar, VL, IVT, DL, DAG, Subtarget);
+      return DAG.getNode(ISD::BITCAST, DL, VT, Splat);
+    }
+    return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
+  }
+
   // Simplest case is that the operand needs to be promoted to XLenVT.
   if (Scalar.getValueType().bitsLE(XLenVT)) {
     // If the operand is a constant, sign extend to increase our chances
diff --git a/llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll b/llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll
index df1bd889c1042a..9c7ad239bcade3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll
@@ -1,6 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 1 x bfloat> @masked_load_nxv1bf16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr %a, i32 2, <vscale x 1 x i1> %mask, <vscale x 1 x bfloat> undef)
+  ret <vscale x 1 x bfloat> %load
+}
+declare <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x bfloat>)
 
 define <vscale x 1 x half> @masked_load_nxv1f16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv1f16:
@@ -35,6 +48,17 @@ define <vscale x 1 x double> @masked_load_nxv1f64(ptr %a, <vscale x 1 x i1> %mas
 }
 declare <vscale x 1 x double> @llvm.masked.load.nxv1f64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x double>)
 
+define <vscale x 2 x bfloat> @masked_load_nxv2bf16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr %a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x bfloat> undef)
+  ret <vscale x 2 x bfloat> %load
+}
+declare <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x bfloat>)
+
 define <vscale x 2 x half> @masked_load_nxv2f16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv2f16:
 ; CHECK:       # %bb.0:
@@ -68,6 +92,17 @@ define <vscale x 2 x double> @masked_load_nxv2f64(ptr %a, <vscale x 2 x i1> %mas
 }
 declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
 
+define <vscale x 4 x bfloat> @masked_load_nxv4bf16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr %a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x bfloat> undef)
+  ret <vscale x 4 x bfloat> %load
+}
+declare <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x bfloat>)
+
 define <vscale x 4 x half> @masked_load_nxv4f16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv4f16:
 ; CHECK:       # %bb.0:
@@ -101,6 +136,17 @@ define <vscale x 4 x double> @masked_load_nxv4f64(ptr %a, <vscale x 4 x i1> %mas
 }
 declare <vscale x 4 x double> @llvm.masked.load.nxv4f64(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x double>)
 
+define <vscale x 8 x bfloat> @masked_load_nxv8bf16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x bfloat> undef)
+  ret <vscale x 8 x bfloat> %load
+}
+declare <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x bfloat>)
+
 define <vscale x 8 x half> @masked_load_nxv8f16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv8f16:
 ; CHECK:       # %bb.0:
@@ -134,6 +180,17 @@ define <vscale x 8 x double> @masked_load_nxv8f64(ptr %a, <vscale x 8 x i1> %mas
 }
 declare <vscale x 8 x double> @llvm.masked.load.nxv8f64(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x double>)
 
+define <vscale x 16 x bfloat> @masked_load_nxv16bf16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x bfloat> undef)
+  ret <vscale x 16 x bfloat> %load
+}
+declare <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x bfloat>)
+
 define <vscale x 16 x half> @masked_load_nxv16f16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv16f16:
 ; CHECK:       # %bb.0:
@@ -156,6 +213,17 @@ define <vscale x 16 x float> @masked_load_nxv16f32(ptr %a, <vscale x 16 x i1> %m
 }
 declare <vscale x 16 x float> @llvm.masked.load.nxv16f32(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x float>)
 
+define <vscale x 32 x bfloat> @masked_load_nxv32bf16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_load_nxv32bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m8, ta, ma
+; CHECK-NEXT:    vle16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr %a, i32 2, <vscale x 32 x i1> %mask, <vscale x 32 x bfloat> undef)
+  ret <vscale x 32 x bfloat> %load
+}
+declare <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr, i32, <vscale x 32 x i1>, <vscale x 32 x bfloat>)
+
 define <vscale x 32 x half> @masked_load_nxv32f16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv32f16:
 ; CHECK:       # %bb.0:
diff --git a/llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll b/llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll
index 17193aef1dff9e..ddb56e0d979a1c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll
@@ -1,6 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+define void @masked_store_nxv1bf16(<vscale x 1 x bfloat> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat> %val, ptr %a, i32 2, <vscale x 1 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat>, ptr, i32, <vscale x 1 x i1>)
 
 define void @masked_store_nxv1f16(<vscale x 1 x half> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv1f16:
@@ -35,6 +48,17 @@ define void @masked_store_nxv1f64(<vscale x 1 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv1f64.p0(<vscale x 1 x double>, ptr, i32, <vscale x 1 x i1>)
 
+define void @masked_store_nxv2bf16(<vscale x 2 x bfloat> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat> %val, ptr %a, i32 2, <vscale x 2 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat>, ptr, i32, <vscale x 2 x i1>)
+
 define void @masked_store_nxv2f16(<vscale x 2 x half> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv2f16:
 ; CHECK:       # %bb.0:
@@ -68,6 +92,17 @@ define void @masked_store_nxv2f64(<vscale x 2 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double>, ptr, i32, <vscale x 2 x i1>)
 
+define void @masked_store_nxv4bf16(<vscale x 4 x bfloat> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat> %val, ptr %a, i32 2, <vscale x 4 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat>, ptr, i32, <vscale x 4 x i1>)
+
 define void @masked_store_nxv4f16(<vscale x 4 x half> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv4f16:
 ; CHECK:       # %bb.0:
@@ -101,6 +136,17 @@ define void @masked_store_nxv4f64(<vscale x 4 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv4f64.p0(<vscale x 4 x double>, ptr, i32, <vscale x 4 x i1>)
 
+define void @masked_store_nxv8bf16(<vscale x 8 x bfloat> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat> %val, ptr %a, i32 2, <vscale x 8 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat>, ptr, i32, <vscale x 8 x i1>)
+
 define void @masked_store_nxv8f16(<vscale x 8 x half> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv8f16:
 ; CHECK:       # %bb.0:
@@ -134,6 +180,17 @@ define void @masked_store_nxv8f64(<vscale x 8 x double> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv8f64.p0(<vscale x 8 x double>, ptr, i32, <vscale x 8 x i1>)
 
+define void @masked_store_nxv16bf16(<vscale x 16 x bfloat> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat> %val, ptr %a, i32 2, <vscale x 16 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat>, ptr, i32, <vscale x 16 x i1>)
+
 define void @masked_store_nxv16f16(<vscale x 16 x half> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv16f16:
 ; CHECK:       # %bb.0:
@@ -156,6 +213,17 @@ define void @masked_store_nxv16f32(<vscale x 16 x float> %val, ptr %a, <vscale x
 }
 declare void @llvm.masked.store.nxv16f32.p0(<vscale x 16 x float>, ptr, i32, <vscale x 16 x i1>)
 
+define void @masked_store_nxv32bf16(<vscale x 32 x bfloat> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
+; CHECK-LABEL: masked_store_nxv32bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, m8, ta, ma
+; CHECK-NEXT:    vse16.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat> %val, ptr %a, i32 2, <vscale x 32 x i1> %mask)
+  ret void
+}
+declare void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat>, ptr, i32, <vscale x 32 x i1>)
+
 define void @masked_store_nxv32f16(<vscale x 32 x half> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv32f16:
 ; CHECK:       # %bb.0:
diff --git a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
index be37be06f0e779..189ba08dddc7aa 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
@@ -1,8 +1,16 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+zvfh,+v -target-abi=ilp32d \
-; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
-; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+zvfh,+v -target-abi=lp64d \
-; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+zvfh,+zvfbfmin,+v \
+; RUN:     -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s \
+; RUN:     --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+zvfh,+zvfbfmin,+v \
+; RUN:     -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s \
+; RUN:     --check-prefixes=CHECK,RV64
+; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+zvfhmin,+zvfbfmin,+v \
+; RUN:     -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s \
+; RUN:     --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+zvfhmin,+zvfbfmin,+v \
+; RUN:     -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s \
+; RUN:     --check-prefixes=CHECK,RV64
 
 declare <vscale x 1 x i8> @llvm.masked.gather.nxv1i8.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
 
@@ -1257,6 +1265,206 @@ define void @mgather_nxv16i64(<vscale x 8 x ptr> %ptrs0, <vscale x 8 x ptr> %ptr
   ret void
 }
 
+declare <vscale x 1 x bfloat> @llvm.masked.gather.nxv1bf16.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x bfloat>)
+
+define <vscale x 1 x bfloat> @mgather_nxv1bf16(<vscale x 1 x ptr> %ptrs, <vscale x 1 x i1> %m, <vscale x 1 x bfloat> %passthru) {
+; RV32-LABEL: mgather_nxv1bf16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a0, zero, e16, mf4, ta, mu
+; RV32-NEXT:    vluxei32.v v9, (zero), v8, v0.t
+; RV32-NEXT:    vmv1r.v v8, v9
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_nxv1bf16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a0, zero, e16, mf4, ta, mu
+; RV64-NEXT:    vluxei64.v v9, (zero), v8, v0.t
+; RV64-NEXT:    vmv1r.v v8, v9
+; RV64-NEXT:    ret
+  %v = call <vscale x 1 x bfloat> @llvm.masked.gather.nxv1bf16.nxv1p0(<vscale x 1 x ptr> %ptrs, i32 2, <vscale x 1 x i1> %m, <vscale x 1 x bfloat> %passthru)
+  ret <vscale x 1 x bfloat> %v
+}
+
+declare <vscale x 2 x bfloat> @llvm.masked.gather.nxv2bf16.nxv2p0(<vscale x 2 x ptr>, i32, <vscale x 2 x i1>, <vscale x 2 x bfloat>)
+
+define <vscale x 2 x bfloat> @mgather_nxv2bf16(<vscale x 2 x ptr> %ptrs, <vscale x 2 x i1> %m, <vscale x 2 x bfloat> %passthru) {
+; RV32-LABEL: mgather_nxv2bf16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
+; RV32-NEXT:    vluxei32.v v9, (zero), v8, v0.t
+; RV32-NEXT:    vmv1r.v v8, v9
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_nxv2bf16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
+; RV64-NEXT:    vluxei64.v v10, (zero), v8, v0.t
+; RV64-NEXT:    vmv1r.v v8, v10
+; RV64-NEXT:    ret
+  %v = call <vscale x 2 x bfloat> @llvm.masked.gather.nxv2bf16.nxv2p0(<vscale x 2 x ptr> %ptrs, i32 2, <vscale x 2 x i1> %m, <vscale x 2 x bfloat> %passthru)
+  ret <vscale x 2 x bfloat> %v
+}
+
+declare <vscale x 4 x bfloat> @llvm.masked.gather.nxv4bf16.nxv4p0(<vscale x 4 x ptr>, i32, <vscale x 4 x i1>, <vscale x 4 x bfloat>)
+
+define <vscale x 4 x bfloat> @mgather_nxv4bf16(<vscale x 4 x ptr> %ptrs, <vscale x 4 x i1> %m, <vscale x 4 x bfloat> %passthru) {
+; RV32-LABEL: mgather_nxv4bf16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a0, zero, e16, m1, ta, mu
+; RV32-NEXT:    vluxei32.v v10, (zero), v8, v0.t
+; RV32-NEXT:    vmv.v.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_nxv4bf16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a0, zero, e16, m1, ta, mu
+; RV64-NEXT:    vluxei64.v v12, (zero), v8, v0.t
+; RV64-NEXT:    vmv.v.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x bfloat> @llvm.masked.gather.nxv4bf16.nxv4p0(<vscale x 4 x ptr> %ptrs, i32 2, <vscale x 4 x i1> %m, <vscale x 4 x bfloat> %passthru)
+  ret <vscale x 4 x bfloat> %v
+}
+
+define <vscale x 4 x bfloat> @mgather_truemask_nxv4bf16(<vscale x 4 x ptr> %ptrs, <vscale x 4 x bfloat> %passthru) {
+; RV32-LABEL: mgather_truemask_nxv4bf16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; RV32-NEXT:    vluxei32.v v10, (zero), v8
+; RV32-NEXT:    vmv.v.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_truemask_nxv4bf16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; RV64-NEXT:    vluxei64.v v12, (zero), v8
+; RV64-NEXT:    vmv.v.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x bfloat> @llvm.masked.gather.nxv4bf16.nxv4p0(<vscale x 4 x ptr> %ptrs, i32 2, <vscale x 4 x i1> splat (i1 1), <vscale x 4 x bfloat> %passthru)
+  ret <vscale x 4 x bfloat> %v
+}
+
+define...
[truncated]

@@ -807,18 +887,6 @@ define <vscale x 1 x i8> @zero_strided_unmasked_vpload_nxv1i8_i8(ptr %ptr) {

; Test unmasked float zero strided
define <vscale x 1 x half> @zero_strided_unmasked_vpload_nxv1f16(ptr %ptr) {
; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_nxv1f16:
Collaborator:
Conflicting prefix here?

lukel97 (author):
Whoops, good catch. It's ugly, but I've crammed in some more check prefixes. Maybe we can refactor these +optimized-zero-stride-load tests into two separate tests, one with a function attribute that adds the feature and one without.
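
A rough sketch of the attribute variant (hypothetical function name; assumes the per-function feature string is "+optimized-zero-stride-load"):

define <vscale x 1 x half> @zero_strided_vpload_opt(ptr %ptr, i32 zeroext %evl) "target-features"="+optimized-zero-stride-load" {
  ; With the feature enabled on this function, the zero-strided load should
  ; presumably stay a strided vlse16 instead of being combined into a splat.
  %v = call <vscale x 1 x half> @llvm.experimental.vp.strided.load.nxv1f16.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 1), i32 %evl)
  ret <vscale x 1 x half> %v
}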

lukel97 force-pushed the zvfhmin-zvfbfmin-custom-lower-memory-ops branch from d373bcc to b7b74cd on September 24, 2024 at 18:27
topperc (Collaborator) left a comment:
LGTM

      else
-        setOperationAction(ISD::SPLAT_VECTOR, EltVT, Custom);
-      setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+        setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
topperc (Collaborator) commented Sep 25, 2024:
Is this part tested? Specifically the scalar element type not being legal.

lukel97 (author):
I pushed a commit to this PR to test the scalar-not-legal path for experimental_vp_splat, and added tests for the existing splat_vector lowering in 9583215
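
For context, the scalar-not-legal path is exercised by IR roughly like this (a sketch; the real tests live in llvm/test/CodeGen/RISCV/rvv/vp-splat.ll, and the exact intrinsic mangling may differ):

declare <vscale x 1 x bfloat> @llvm.experimental.vp.splat.nxv1bf16(bfloat, <vscale x 1 x i1>, i32)

define <vscale x 1 x bfloat> @vp_splat_nxv1bf16(bfloat %val, <vscale x 1 x i1> %m, i32 zeroext %evl) {
  ; With zvfbfmin but no scalar zfbfmin, lowerScalarSplat bitcasts the bf16
  ; scalar to i16 and splats it with vmv.v.x rather than vfmv.v.f.
  %splat = call <vscale x 1 x bfloat> @llvm.experimental.vp.splat.nxv1bf16(bfloat %val, <vscale x 1 x i1> %m, i32 %evl)
  ret <vscale x 1 x bfloat> %splat
}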

; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
; CHECK-NEXT:    vmv.v.x v8, a1
; CHECK-NEXT:    ret
; NOZFMIN-LABEL: vp_splat_nxv1bf16:
Collaborator:
Missing an H in the check prefix name

lukel97 (author):
I wrote ZFMIN since it has both zfhmin and zfbfmin. Would it be better to be explicit and use ZFHMINZFBFMIN?

Collaborator:
In that case, I don't have a problem with keeping it as ZFMIN.

topperc (Collaborator) left a comment:
LGTM other than using ZFMIN instead of ZFHMIN in the check-prefixes

topperc (Collaborator) commented Sep 25, 2024

#109900 will affect the tests in this patch so one of us will need to rebase.

lukel97 (author) commented Sep 25, 2024

> #109900 will affect the tests in this patch so one of us will need to rebase.

I don't mind rebasing so feel free to land that first

topperc (Collaborator) commented Sep 25, 2024

> #109900 will affect the tests in this patch so one of us will need to rebase.
>
> I don't mind rebasing so feel free to land that first

Done.

lukel97 force-pushed the zvfhmin-zvfbfmin-custom-lower-memory-ops branch from 62ea524 to 4553da1 on September 25, 2024 at 17:46
lukel97 merged commit f172c31 into llvm:main on Sep 25, 2024
5 of 7 checks passed
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Nov 7, 2024
This is also split off from the zvfhmin/zvfbfmin isLegalElementTypeForRVV work.

Enabling this will cause SLP and RISCVGatherScatterLowering to emit @llvm.experimental.vp.strided.{load,store} intrinsics, and support for this was added in llvm#109387 and llvm#114750.
lukel97 added a commit that referenced this pull request Nov 7, 2024
…115264)

This is also split off from the zvfhmin/zvfbfmin
isLegalElementTypeForRVV work.

Enabling this will cause SLP and RISCVGatherScatterLowering to emit
@llvm.experimental.vp.strided.{load,store} intrinsics, and codegen
support for this was added in #109387 and #114750.