Skip to content

Commit 9f3bf77

Browse files
committed
[MLIR][NVVM] Add support for f32 in redux.sync Op
This change adds support for the f32 variants of the `redux.sync` instruction in the NVVM Dialect through the newly added intrinsics for the same.
1 parent 22f5268 commit 9f3bf77

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,13 @@ def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">;
257257
def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
258258
def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
259259
def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">;
260+
def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">;
261+
def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">;
260262

261263
/// Enum attribute of the different kinds.
262264
def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
263265
[ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
264-
ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
266+
ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
265267
let genSpecializedAttr = 0;
266268
let cppNamespace = "::mlir::NVVM";
267269
}
@@ -273,9 +275,24 @@ def NVVM_ReduxOp :
273275
Results<(outs LLVM_Type:$res)>,
274276
Arguments<(ins LLVM_Type:$val,
275277
ReduxKindAttr:$kind,
276-
I32:$mask_and_clamp)> {
278+
I32:$mask_and_clamp,
279+
DefaultValuedAttr<BoolAttr, "false">:$abs,
280+
DefaultValuedAttr<BoolAttr, "false">:$nan)> {
281+
let summary = "Redux Sync Op";
282+
let description = [{
283+
`redux.sync` performs a reduction operation `kind` of the 32 bit source
284+
register across all non-exited threads in the membermask.
285+
286+
The `abs` and `nan` attributes can be used in the case of f32 input type,
287+
where the `abs` attribute causes the absolute value of the input to be used
288+
in the reduction operation, and the `nan` attribute causes the reduction
289+
operation to return NaN if any of the inputs to participating threads are
290+
NaN.
291+
292+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync)
293+
}];
277294
string llvmBuilder = [{
278-
auto intId = getReduxIntrinsicId($_resultType, $kind);
295+
auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
279296
$res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
280297
}];
281298
let assemblyFormat = [{

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,17 @@ using namespace mlir;
2525
using namespace mlir::LLVM;
2626
using mlir::LLVM::detail::createIntrinsicCall;
2727

28+
#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
29+
hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
30+
: llvm::Intrinsic::nvvm_redux_sync_f##op##abs
31+
32+
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
33+
hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
34+
2835
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
29-
NVVM::ReduxKind kind) {
30-
if (!resultType->isIntegerTy(32))
36+
NVVM::ReduxKind kind,
37+
bool hasAbs, bool hasNaN) {
38+
if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
3139
llvm_unreachable("unsupported data type for redux");
3240

3341
switch (kind) {
@@ -47,6 +55,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
4755
return llvm::Intrinsic::nvvm_redux_sync_max;
4856
case NVVM::ReduxKind::MIN:
4957
return llvm::Intrinsic::nvvm_redux_sync_min;
58+
case NVVM::ReduxKind::FMIN:
59+
return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
60+
case NVVM::ReduxKind::FMAX:
61+
return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
5062
}
5163
llvm_unreachable("unknown redux kind");
5264
}

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
411411
llvm.return %r1 : i32
412412
}
413413

414+
llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
415+
// CHECK: nvvm.redux.sync fmin %{{.*}}
416+
%r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
417+
// CHECK: nvvm.redux.sync fmin %{{.*}}
418+
%r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
419+
// CHECK: nvvm.redux.sync fmin %{{.*}}
420+
%r3 = nvvm.redux.sync fmin %value, %offset {NaN = true}: f32 -> f32
421+
// CHECK: nvvm.redux.sync fmin %{{.*}}
422+
%r4 = nvvm.redux.sync fmin %value, %offset {abs = true, NaN = true}: f32 -> f32
423+
// CHECK: nvvm.redux.sync fmax %{{.*}}
424+
%r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
425+
// CHECK: nvvm.redux.sync fmax %{{.*}}
426+
%r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
427+
// CHECK: nvvm.redux.sync fmax %{{.*}}
428+
%r7 = nvvm.redux.sync fmax %value, %offset {NaN = true}: f32 -> f32
429+
// CHECK: nvvm.redux.sync fmax %{{.*}}
430+
%r8 = nvvm.redux.sync fmax %value, %offset {abs = true, NaN = true}: f32 -> f32
431+
llvm.return %r1 : f32
432+
}
414433

415434
// -----
416435

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
780780
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
781781
llvm.return
782782
}
783+
784+
// -----
785+
// CHECK-LABEL: @nvvm_redux_sync
786+
llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
787+
// CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
788+
%0 = nvvm.redux.sync add %value, %offset: i32 -> i32
789+
// CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
790+
%1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
791+
// CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
792+
%2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
793+
// CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
794+
%3 = nvvm.redux.sync and %value, %offset: i32 -> i32
795+
// CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
796+
%4 = nvvm.redux.sync or %value, %offset: i32 -> i32
797+
// CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
798+
%5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
799+
// CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
800+
%6 = nvvm.redux.sync max %value, %offset: i32 -> i32
801+
// CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
802+
%7 = nvvm.redux.sync min %value, %offset: i32 -> i32
803+
llvm.return
804+
}
805+
806+
// CHECK-LABEL: @nvvm_redux_sync_f32
807+
llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
808+
// CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
809+
%0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
810+
// CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
811+
%1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
812+
// CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
813+
%2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
814+
// CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
815+
%3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
816+
// CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
817+
%4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
818+
// CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
819+
%5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
820+
// CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
821+
%6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
822+
// CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
823+
%7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
824+
llvm.return
825+
}

0 commit comments

Comments
 (0)