Skip to content

Commit f0d6c86

Browse files
committed
[mlir][gpu] Align reduction operations with vector combining kinds
The motivation for this change is explained in #72354. Before this change, we could not tell between signed/unsigned minimum/maximum and NaN treatment for floating point values.

The mapping of old reduction operations to the new ones is as follows:

* `min` --> `minsi` for ints, `minf` for floats
* `max` --> `maxsi` for ints, `maxf` for floats

New reduction kinds not represented in the old enum: `minui`, `maxui`, `minimumf`, `maximumf`.

As a next step, I would like to have a common definition of combining kinds used by the `vector` and `gpu` dialects. Separately, the GPU to SPIR-V lowering does not yet properly handle zero and NaN values -- the behavior of floating point min/max group reductions is not specified by the SPIR-V spec.

Issue: #72354
1 parent f4a4e2f commit f0d6c86

File tree

10 files changed

+313
-159
lines changed

10 files changed

+313
-159
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -868,25 +868,41 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
868868
}];
869869
}
870870

871-
// add, mul mirror the XLA ComparisonDirection enum.
871+
// These mirror the reduction combining kinds from the vector dialect.
872872
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
873-
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
874-
def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
875-
def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
876-
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
877-
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
878-
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
873+
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
874+
def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
875+
def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
876+
// Follows the `arith.minnumf` semantics.
877+
def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
878+
def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
879+
def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
880+
// Follows the `arith.maxnumf` semantics.
881+
def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
882+
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
883+
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
884+
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
885+
// Follows the `arith.minimumf` semantics.
886+
def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
887+
// Follows the `arith.maximumf` semantics.
888+
def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
879889

880890
def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
881891
"built-in reduction operations supported by gpu.allreduce.",
882892
[
883893
GPU_AllReduceOpAdd,
884-
GPU_AllReduceOpAnd,
885-
GPU_AllReduceOpMax,
886-
GPU_AllReduceOpMin,
887894
GPU_AllReduceOpMul,
895+
GPU_AllReduceOpMinUI,
896+
GPU_AllReduceOpMinSI,
897+
GPU_AllReduceOpMinF,
898+
GPU_AllReduceOpMaxUI,
899+
GPU_AllReduceOpMaxSI,
900+
GPU_AllReduceOpMaxF,
901+
GPU_AllReduceOpAnd,
888902
GPU_AllReduceOpOr,
889-
GPU_AllReduceOpXor
903+
GPU_AllReduceOpXor,
904+
GPU_AllReduceOpMinimumF,
905+
GPU_AllReduceOpMaximumF
890906
]>{
891907
let genSpecializedAttr = 0;
892908
let cppNamespace = "::mlir::gpu";
@@ -918,8 +934,11 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
918934

919935
compute the sum of each work item's %0 value. The first version specifies
920936
the accumulation as operation, whereas the second version specifies the
921-
accumulation as code region. The accumulation operation must be one of:
922-
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
937+
accumulation as code region. The reduction operation must be one of:
938+
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
939+
`or`, `xor`
940+
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
941+
`maximumf`
923942

924943
If `uniform` flag is set either none or all work items of a workgroup
925944
need to execute this op in convergence.
@@ -951,7 +970,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
951970
```
952971

953972
If `uniform` flag is set either none or all work items of a subgroup
954-
need to execute this op in convergence.
973+
need to execute this op in convergence. The reduction operation must be one
974+
of:
975+
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
976+
`or`, `xor`
977+
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
978+
`maximumf`
955979
}];
956980
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
957981
(`uniform` $uniform^)? attr-dict

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,28 @@ convertReduxKind(gpu::AllReduceOperation mode) {
6565
switch (mode) {
6666
case gpu::AllReduceOperation::ADD:
6767
return NVVM::ReduxKind::ADD;
68+
case gpu::AllReduceOperation::MUL:
69+
return std::nullopt;
70+
case gpu::AllReduceOperation::MINSI:
71+
return NVVM::ReduxKind::MIN;
72+
case gpu::AllReduceOperation::MINUI:
73+
return std::nullopt;
74+
case gpu::AllReduceOperation::MINF:
75+
return NVVM::ReduxKind::MIN;
76+
case gpu::AllReduceOperation::MAXSI:
77+
return NVVM::ReduxKind::MAX;
78+
case gpu::AllReduceOperation::MAXUI:
79+
return std::nullopt;
80+
case gpu::AllReduceOperation::MAXF:
81+
return NVVM::ReduxKind::MAX;
6882
case gpu::AllReduceOperation::AND:
6983
return NVVM::ReduxKind::AND;
70-
case gpu::AllReduceOperation::MAX:
71-
return NVVM::ReduxKind::MAX;
72-
case gpu::AllReduceOperation::MIN:
73-
return NVVM::ReduxKind::MIN;
7484
case gpu::AllReduceOperation::OR:
7585
return NVVM::ReduxKind::OR;
7686
case gpu::AllReduceOperation::XOR:
7787
return NVVM::ReduxKind::XOR;
78-
case gpu::AllReduceOperation::MUL:
88+
case gpu::AllReduceOperation::MINIMUMF:
89+
case gpu::AllReduceOperation::MAXIMUMF:
7990
return std::nullopt;
8091
}
8192
return std::nullopt;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -503,26 +503,52 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
503503
return std::nullopt;
504504
}
505505

506+
// TODO: The SPIR-V spec does not specify how -0.0 / +0.0 and NaN values are
507+
// handled in *FMin/*FMax reduction ops. We should double account for this not
508+
// being defined in this conversion.
509+
506510
using ReduceType = gpu::AllReduceOperation;
507-
namespace spv = spirv;
508511
const OpHandler handlers[] = {
509512
{ReduceType::ADD,
510-
&createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
511-
&createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
513+
&createGroupReduceOpImpl<spirv::GroupIAddOp,
514+
spirv::GroupNonUniformIAddOp>,
515+
&createGroupReduceOpImpl<spirv::GroupFAddOp,
516+
spirv::GroupNonUniformFAddOp>},
512517
{ReduceType::MUL,
513-
&createGroupReduceOpImpl<spv::GroupIMulKHROp,
514-
spv::GroupNonUniformIMulOp>,
515-
&createGroupReduceOpImpl<spv::GroupFMulKHROp,
516-
spv::GroupNonUniformFMulOp>},
517-
{ReduceType::MIN,
518-
&createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
519-
&createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
520-
{ReduceType::MAX,
521-
&createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
522-
&createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
523-
};
524-
525-
for (auto &handler : handlers)
518+
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
519+
spirv::GroupNonUniformIMulOp>,
520+
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
521+
spirv::GroupNonUniformFMulOp>},
522+
{ReduceType::MINUI,
523+
&createGroupReduceOpImpl<spirv::GroupUMinOp,
524+
spirv::GroupNonUniformUMinOp>,
525+
nullptr},
526+
{ReduceType::MINSI,
527+
&createGroupReduceOpImpl<spirv::GroupSMinOp,
528+
spirv::GroupNonUniformSMinOp>,
529+
nullptr},
530+
{ReduceType::MINF, nullptr,
531+
&createGroupReduceOpImpl<spirv::GroupFMinOp,
532+
spirv::GroupNonUniformFMinOp>},
533+
{ReduceType::MAXUI,
534+
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
535+
spirv::GroupNonUniformUMaxOp>,
536+
nullptr},
537+
{ReduceType::MAXSI,
538+
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
539+
spirv::GroupNonUniformSMaxOp>,
540+
nullptr},
541+
{ReduceType::MAXF, nullptr,
542+
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
543+
spirv::GroupNonUniformFMaxOp>},
544+
{ReduceType::MINIMUMF, nullptr,
545+
&createGroupReduceOpImpl<spirv::GroupFMinOp,
546+
spirv::GroupNonUniformFMinOp>},
547+
{ReduceType::MAXIMUMF, nullptr,
548+
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
549+
spirv::GroupNonUniformFMaxOp>}};
550+
551+
for (const OpHandler &handler : handlers)
526552
if (handler.type == opType)
527553
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
528554

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
#include "mlir/IR/TypeUtilities.h"
2828
#include "mlir/Interfaces/FunctionImplementation.h"
2929
#include "mlir/Interfaces/SideEffectInterfaces.h"
30+
#include "mlir/Support/LogicalResult.h"
3031
#include "mlir/Transforms/InliningUtils.h"
32+
#include "llvm/ADT/STLExtras.h"
3133
#include "llvm/ADT/TypeSwitch.h"
3234
#include "llvm/Support/CommandLine.h"
3335
#include "llvm/Support/ErrorHandling.h"
@@ -485,12 +487,23 @@ static LogicalResult verifyAttributions(Operation *op,
485487
// AllReduceOp
486488
//===----------------------------------------------------------------------===//
487489

488-
static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
489-
Type resType) {
490-
return (opName != gpu::AllReduceOperation::AND &&
491-
opName != gpu::AllReduceOperation::OR &&
492-
opName != gpu::AllReduceOperation::XOR) ||
493-
llvm::isa<IntegerType>(resType);
490+
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
491+
Type resType) {
492+
using Kind = gpu::AllReduceOperation;
493+
if (llvm::is_contained(
494+
{Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
495+
if (!isa<FloatType>(resType))
496+
return failure();
497+
}
498+
499+
if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
500+
Kind::AND, Kind::OR, Kind::XOR},
501+
opName)) {
502+
if (!isa<IntegerType>(resType))
503+
return failure();
504+
}
505+
506+
return success();
494507
}
495508

496509
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -517,12 +530,13 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
517530
return emitError("expected gpu.yield op in region");
518531
} else {
519532
gpu::AllReduceOperation opName = *getOp();
520-
if (!verifyReduceOpAndType(opName, getType())) {
521-
return emitError()
522-
<< '`' << gpu::stringifyAllReduceOperation(opName)
523-
<< "` accumulator is only compatible with Integer type";
533+
if (failed(verifyReduceOpAndType(opName, getType()))) {
534+
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
535+
<< "` reduction operation is not compatible with type "
536+
<< getType();
524537
}
525538
}
539+
526540
return success();
527541
}
528542

@@ -573,9 +587,10 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
573587

574588
LogicalResult gpu::SubgroupReduceOp::verify() {
575589
gpu::AllReduceOperation opName = getOp();
576-
if (!verifyReduceOpAndType(opName, getType())) {
590+
if (failed(verifyReduceOpAndType(opName, getType()))) {
577591
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
578-
<< "` accumulator is only compatible with Integer type";
592+
<< "` reduction operation is not compatible with type "
593+
<< getType();
579594
}
580595
return success();
581596
}

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -214,54 +214,49 @@ struct GpuAllReduceRewriter {
214214

215215
/// Returns an accumulator factory that creates an op specified by opName.
216216
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
217+
using Kind = gpu::AllReduceOperation;
217218
bool isFloatingPoint = isa<FloatType>(valueType);
218219
switch (opName) {
219-
case gpu::AllReduceOperation::ADD:
220+
case Kind::ADD:
220221
return isFloatingPoint ? getFactory<arith::AddFOp>()
221222
: getFactory<arith::AddIOp>();
222-
case gpu::AllReduceOperation::MUL:
223+
case Kind::MUL:
223224
return isFloatingPoint ? getFactory<arith::MulFOp>()
224225
: getFactory<arith::MulIOp>();
225-
case gpu::AllReduceOperation::AND:
226+
case Kind::MINSI:
227+
return getFactory<arith::MinSIOp>();
228+
case Kind::MINUI:
229+
return getFactory<arith::MinUIOp>();
230+
case Kind::MINF:
231+
return getFactory<arith::MinNumFOp>();
232+
case Kind::MAXSI:
233+
return getFactory<arith::MaxSIOp>();
234+
case Kind::MAXUI:
235+
return getFactory<arith::MaxUIOp>();
236+
case Kind::MAXF:
237+
return getFactory<arith::MaxNumFOp>();
238+
case Kind::AND:
226239
return getFactory<arith::AndIOp>();
227-
case gpu::AllReduceOperation::OR:
240+
case Kind::OR:
228241
return getFactory<arith::OrIOp>();
229-
case gpu::AllReduceOperation::XOR:
242+
case Kind::XOR:
230243
return getFactory<arith::XOrIOp>();
231-
case gpu::AllReduceOperation::MAX:
232-
return isFloatingPoint
233-
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
234-
arith::CmpFPredicate::UGT>()
235-
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
236-
arith::CmpIPredicate::ugt>();
237-
case gpu::AllReduceOperation::MIN:
238-
return isFloatingPoint
239-
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
240-
arith::CmpFPredicate::ULT>()
241-
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
242-
arith::CmpIPredicate::ult>();
244+
case Kind::MINIMUMF:
245+
return getFactory<arith::MinimumFOp>();
246+
case Kind::MAXIMUMF:
247+
return getFactory<arith::MaximumFOp>();
243248
}
244249
llvm_unreachable("unknown GPU AllReduceOperation");
245250
}
246251

247252
/// Returns an accumulator factory that creates an op of type T.
248253
template <typename T>
249254
AccumulatorFactory getFactory() {
250-
return [&](Value lhs, Value rhs) {
255+
return [this](Value lhs, Value rhs) {
251256
return create<T>(lhs.getType(), lhs, rhs);
252257
};
253258
}
254259

255-
/// Returns an accumulator for comparison such as min, max. T is the type
256-
/// of the compare op.
257-
template <typename T, typename PredicateEnum, PredicateEnum predicate>
258-
AccumulatorFactory getCmpFactory() const {
259-
return [&](Value lhs, Value rhs) {
260-
Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
261-
return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
262-
};
263-
}
264-
265260
/// Creates an if-block skeleton and calls the two factories to generate the
266261
/// ops in the `then` and `else` block..
267262
///

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -582,22 +582,22 @@ gpu.module @test_module_30 {
582582
%result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
583583
gpu.return
584584
}
585-
// CHECK-LABEL: func @subgroup_reduce_and
586-
gpu.func @subgroup_reduce_and(%arg0 : i32) {
587-
// CHECK: nvvm.redux.sync and {{.*}}
588-
%result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
585+
// CHECK-LABEL: @subgroup_reduce_minsi
586+
gpu.func @subgroup_reduce_minsi(%arg0 : i32) {
587+
// CHECK: nvvm.redux.sync min {{.*}}
588+
%result = gpu.subgroup_reduce minsi %arg0 uniform {} : (i32) -> (i32)
589589
gpu.return
590590
}
591-
// CHECK-LABEL: @subgroup_reduce_max
592-
gpu.func @subgroup_reduce_max(%arg0 : i32) {
591+
// CHECK-LABEL: @subgroup_reduce_maxsi
592+
gpu.func @subgroup_reduce_maxsi(%arg0 : i32) {
593593
// CHECK: nvvm.redux.sync max {{.*}}
594-
%result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
594+
%result = gpu.subgroup_reduce maxsi %arg0 uniform {} : (i32) -> (i32)
595595
gpu.return
596596
}
597-
// CHECK-LABEL: @subgroup_reduce_min
598-
gpu.func @subgroup_reduce_min(%arg0 : i32) {
599-
// CHECK: nvvm.redux.sync min {{.*}}
600-
%result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
597+
// CHECK-LABEL: func @subgroup_reduce_and
598+
gpu.func @subgroup_reduce_and(%arg0 : i32) {
599+
// CHECK: nvvm.redux.sync and {{.*}}
600+
%result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
601601
gpu.return
602602
}
603603
// CHECK-LABEL: @subgroup_reduce_or

0 commit comments

Comments
 (0)