-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][NVVM] Add support for f32 in redux.sync Op #128137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add support for f32 in redux.sync Op #128137
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) Changes: This change adds support for the f32 variants of the `redux.sync` instruction in the NVVM Dialect through the newly added intrinsics for the same. Full diff: https://github.com/llvm/llvm-project/pull/128137.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..92bc73b0d03ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -257,11 +257,13 @@ def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">;
def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">;
+def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">;
+def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">;
/// Enum attribute of the different kinds.
def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
[ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
- ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+ ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -273,9 +275,11 @@ def NVVM_ReduxOp :
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
- I32:$mask_and_clamp)> {
+ I32:$mask_and_clamp,
+ DefaultValuedAttr<BoolAttr, "false">:$abs,
+ DefaultValuedAttr<BoolAttr, "false">:$nan)> {
string llvmBuilder = [{
- auto intId = getReduxIntrinsicId($_resultType, $kind);
+ auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
$res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 8b13735774663..721778be6ba20 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,9 +25,18 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
+#define REDUX_F32_ID_IMPL(op, abs, nan) \
+ hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##nan \
+ : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
+
+#define GET_REDUX_F32_ID(op, abs, nan) \
+ hasAbs ? REDUX_F32_ID_IMPL(op, abs, nan) : REDUX_F32_ID_IMPL(op, , nan)
+
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
- NVVM::ReduxKind kind) {
- if (!resultType->isIntegerTy(32))
+ NVVM::ReduxKind kind,
+ bool hasAbs,
+ bool hasNaN) {
+ if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
llvm_unreachable("unsupported data type for redux");
switch (kind) {
@@ -47,6 +56,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
return llvm::Intrinsic::nvvm_redux_sync_max;
case NVVM::ReduxKind::MIN:
return llvm::Intrinsic::nvvm_redux_sync_min;
+ case NVVM::ReduxKind::FMIN:
+ return GET_REDUX_F32_ID(min, _abs, _NaN);
+ case NVVM::ReduxKind::FMAX:
+ return GET_REDUX_F32_ID(max, _abs, _NaN);
}
llvm_unreachable("unknown redux kind");
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index dd54acd1e317e..85998d4e66254 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
llvm.return %r1 : i32
}
+llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r3 = nvvm.redux.sync fmin %value, %offset {NaN = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r4 = nvvm.redux.sync fmin %value, %offset {abs = true, NaN = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r7 = nvvm.redux.sync fmax %value, %offset {NaN = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r8 = nvvm.redux.sync fmax %value, %offset {abs = true, NaN = true}: f32 -> f32
+ llvm.return %r1 : f32
+}
// -----
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5ab593452ab66..d11558698d860 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_redux_sync
+llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
+ // CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.redux.sync add %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.redux.sync and %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
+ %4 = nvvm.redux.sync or %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
+ %5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.redux.sync max %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.redux.sync min %value, %offset: i32 -> i32
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_redux_sync_f32
+llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
+ %4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
+ %5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+ llvm.return
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
fdbf42b
to
9e93205
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, it looks like the instruction is fully implemented.
I32:$mask_and_clamp)> { | ||
I32:$mask_and_clamp, | ||
DefaultValuedAttr<BoolAttr, "false">:$abs, | ||
DefaultValuedAttr<BoolAttr, "false">:$nan)> { | ||
string llvmBuilder = [{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a docstring for this op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in the latest revision. Thanks!
This change adds support for the f32 variants of the `redux.sync` instruction in the NVVM Dialect through the newly added intrinsics for the same.
9e93205
to
9f3bf77
Compare
This change adds support for the f32 variants of the `redux.sync` instruction in the NVVM Dialect through the newly added intrinsics for the same.