Skip to content

[MLIR][NVVM] Add support for f32 in redux.sync Op #128137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

Wolfram70
Copy link
Contributor

This change adds support for the f32 variants of the redux.sync instruction in the NVVM Dialect through the newly added intrinsics for the same.

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds support for the f32 variants of the redux.sync instruction in the NVVM Dialect through the newly added intrinsics for the same.


Full diff: https://github.com/llvm/llvm-project/pull/128137.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+7-3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+15-2)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+19)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..92bc73b0d03ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -257,11 +257,13 @@ def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
 def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
 def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
 def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">; 
+def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">;
+def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">;
 
 /// Enum attribute of the different kinds.
 def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
   [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, 
-    ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+   ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -273,9 +275,11 @@ def NVVM_ReduxOp :
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$val,
                  ReduxKindAttr:$kind,
-                 I32:$mask_and_clamp)> {
+                 I32:$mask_and_clamp,
+                 DefaultValuedAttr<BoolAttr, "false">:$abs,
+                 DefaultValuedAttr<BoolAttr, "false">:$nan)> {
   string llvmBuilder = [{
-      auto intId = getReduxIntrinsicId($_resultType, $kind);
+      auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
       $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 8b13735774663..721778be6ba20 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,9 +25,18 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
+#define REDUX_F32_ID_IMPL(op, abs, nan)                                        \
+  hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##nan                    \
+         : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
+
+#define GET_REDUX_F32_ID(op, abs, nan)                                         \
+  hasAbs ? REDUX_F32_ID_IMPL(op, abs, nan) : REDUX_F32_ID_IMPL(op, , nan)
+
 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
-                                               NVVM::ReduxKind kind) {
-  if (!resultType->isIntegerTy(32))
+                                               NVVM::ReduxKind kind,
+                                               bool hasAbs,
+                                               bool hasNaN) {
+  if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
     llvm_unreachable("unsupported data type for redux");
 
   switch (kind) {
@@ -47,6 +56,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
     return llvm::Intrinsic::nvvm_redux_sync_max;
   case NVVM::ReduxKind::MIN:
     return llvm::Intrinsic::nvvm_redux_sync_min;
+  case NVVM::ReduxKind::FMIN:
+    return GET_REDUX_F32_ID(min, _abs, _NaN);
+  case NVVM::ReduxKind::FMAX:
+    return GET_REDUX_F32_ID(max, _abs, _NaN);
   }
   llvm_unreachable("unknown redux kind");
 }
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index dd54acd1e317e..85998d4e66254 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
   llvm.return %r1 : i32
 }
 
+llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
+  // CHECK: nvvm.redux.sync fmin %{{.*}}
+  %r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}
+  %r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}
+  %r3 = nvvm.redux.sync fmin %value, %offset {NaN = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}
+  %r4 = nvvm.redux.sync fmin %value, %offset {abs = true, NaN = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}
+  %r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}
+  %r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}
+  %r7 = nvvm.redux.sync fmax %value, %offset {NaN = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}
+  %r8 = nvvm.redux.sync fmax %value, %offset {abs = true, NaN = true}: f32 -> f32
+  llvm.return %r1 : f32
+}
 
 // -----
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5ab593452ab66..d11558698d860 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_redux_sync
+llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
+  // CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.redux.sync add %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
+  %2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
+  %3 = nvvm.redux.sync and %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
+  %4 = nvvm.redux.sync or %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
+  %5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.redux.sync max %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.redux.sync min %value, %offset: i32 -> i32
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_redux_sync_f32
+llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
+  %2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
+  %3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
+  %4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
+  %5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+  llvm.return
+}

Copy link

github-actions bot commented Feb 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-redux-sync-f32-op branch from fdbf42b to 9e93205 Compare February 21, 2025 05:56
Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, it looks instruction is implemented fully.

I32:$mask_and_clamp)> {
I32:$mask_and_clamp,
DefaultValuedAttr<BoolAttr, "false">:$abs,
DefaultValuedAttr<BoolAttr, "false">:$nan)> {
string llvmBuilder = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add a doc-string for this op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in the latest revision. Thanks!

This change adds support for the f32 variants of the `redux.sync`
instruction in the NVVM Dialect through the newly added intrinsics for
the same.
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-redux-sync-f32-op branch from 9e93205 to 9f3bf77 Compare February 21, 2025 10:48
@Wolfram70 Wolfram70 merged commit 7cda365 into llvm:main Feb 24, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants